Hi All
I’m interested in implementing a jumping-knowledge layer and incorporating it into GNN learning. I’ve implemented the `concat` and `lstm` aggregations as described in the Jumping Knowledge paper (Xu et al., 2018), but I’m concerned that I’ve done something incorrectly, as my model is not learning properly.
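For context, the two aggregators combine each node’s per-layer embeddings differently. Here is a toy sketch of what each one computes (illustrative sizes only, and with uniform attention weights standing in for the bi-LSTM scorer):

```python
import torch

# toy illustration (not my actual code): embeddings of a single node
# across 3 GNN layers, stacked as (num_layers, feat)
hs = torch.randn(3, 8)

# 'cat' aggregation: concatenate all layers into one long vector
h_cat = hs.reshape(-1)                      # shape: (3 * 8,)

# 'lstm' aggregation: attention-weighted sum over layers; in the paper the
# weights come from a bi-LSTM + linear scorer, uniform here for simplicity
alpha = torch.full((3,), 1.0 / 3)
h_lstm = (alpha.unsqueeze(-1) * hs).sum(0)  # shape: (8,)
```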
Assume I’m using a GAT network, with the following LSTM layer and linear layers (one of which computes the attention scores) – this is based on code from the PyTorch Geometric example:
```python
from torch.nn import Linear, LSTM

# inside __init__; each GATConv layer outputs num_heads * num_hidden features
if aggregation == 'cat':
    # concat aggregator: one projection over all layers' embeddings
    self.linear = Linear(num_heads * num_hidden * num_layers, num_classes, bias=True)
elif aggregation == 'lstm':
    # lstm aggregator: bidirectional LSTM + attention over the layer embeddings
    self.lstm = LSTM(input_size=num_hidden * num_heads,
                     hidden_size=num_hidden,
                     num_layers=2,
                     batch_first=True,
                     bidirectional=True)
    self.attn = Linear(2 * num_hidden, 1)
    self.linear = Linear(num_heads * num_hidden, num_classes, bias=True)
```
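For reference, PyTorch Geometric also ships a built-in `JumpingKnowledge` module covering both modes, which could serve as a sanity check against my implementation. A minimal construction sketch (placeholder sizes, not my actual hyperparameters):

```python
from torch_geometric.nn import JumpingKnowledge

num_hidden, num_heads, num_layers = 8, 4, 3  # placeholders

# 'cat' mode takes no size arguments
jk_cat = JumpingKnowledge('cat')

# 'lstm' mode needs the per-layer feature size and the number of GNN layers
jk_lstm = JumpingKnowledge('lstm',
                           channels=num_hidden * num_heads,
                           num_layers=num_layers)

# both consume the list of per-layer node embeddings, e.g. h = jk_lstm(xs)
```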
```python
def forward(self, g=None, inputs=None, return_alpha=False, **kwds):
    """
    Forward pass of network.

    Parameters:
    - - - - -
    g: DGL Graph
        the graph
    inputs: tensor
        node features
    return_alpha: bool
        if True, also return the lstm attention weights

    Returns:
    - - - - -
    logits: tensor
        output of the network (after sigmoid)
    """
    h = inputs
    xs = []
    alpha = None  # only populated by the lstm aggregator
    for l in range(self.num_layers):
        # hidden embeddings for each GATConv layer
        h = self.layers[l](g, h).flatten(1)
        # append embeddings to list
        xs.append(h.unsqueeze(-1))

    # LSTM aggregator
    if self.aggregation == 'lstm':
        # input to lstm: xs shape will be (nodes x seq_length x features)
        xs = torch.cat(xs, dim=-1).transpose(1, 2)
        # compute attentions from the concatenated embeddings
        # of the forward / backward LSTM pass
        alpha, _ = self.lstm(xs)
        alpha = self.attn(alpha).squeeze(-1)
        alpha = torch.softmax(alpha, dim=-1)
        # final embedding of JK layer: attention-weighted sum over layers;
        # output will be of size (nodes x classes) after the linear layer
        h = (xs * alpha.unsqueeze(-1)).sum(1)
        h = self.linear(h)
    # CONCAT aggregator
    else:
        h = torch.cat(xs, dim=1).squeeze(-1)
        h = self.linear(h)

    # apply sigmoid activation to jumping-knowledge output
    logits = torch.sigmoid(h)

    if return_alpha:
        return logits, alpha
    return logits
```
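To convince myself the tensor shapes in the `lstm` branch line up, the same arithmetic can be run standalone with dummy tensors (no graph needed; sizes are made up):

```python
import torch

# made-up sizes for the shape check only
num_nodes, num_layers, num_hidden, num_heads = 100, 3, 8, 4
feats = num_hidden * num_heads

xs = torch.randn(num_nodes, num_layers, feats)        # (nodes, seq, features)
lstm = torch.nn.LSTM(feats, num_hidden, num_layers=2,
                     batch_first=True, bidirectional=True)
attn = torch.nn.Linear(2 * num_hidden, 1)

out, _ = lstm(xs)                                     # (nodes, seq, 2*num_hidden)
alpha = torch.softmax(attn(out).squeeze(-1), dim=-1)  # (nodes, seq)
h = (xs * alpha.unsqueeze(-1)).sum(1)                 # (nodes, features)
assert h.shape == (num_nodes, feats)
```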
Does this seem reasonable? When trained and validated on the same data, a vanilla GAT network learns just fine (good training and validation accuracy and loss), while neither the `cat` nor the `lstm` network learns (high loss and low accuracy on both the training and validation sets), even after training for many epochs. I feel this must be an error in my code.
Any help is much appreciated.
k