I’m using the MiniGCDataset
# Evaluate the model on the entire test set, batched into a single graph.
model.eval()
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
# Labels are converted to a long tensor (presumably class indices for the
# loss function -- confirm).
# TODO: consider also computing probs_Y = torch.softmax(model(test_bg), 1).
test_Y = torch.tensor(test_Y, dtype=torch.long)
# No gradients are needed for evaluation, so skip building the autograd graph.
with torch.no_grad():
    prediction = model(test_bg)
    loss = loss_function(prediction, test_Y)
test_losses.append(loss.item())
# Switch back to training mode for any training that follows.
model.train()
The code above works fine with GCN layers but not with GATConv layers: it throws an “Expect number of features to match number of nodes” exception.
I don’t understand why
This is the code I used; it is modeled on this example for node classification: https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py
class GATClassifier(nn.Module):
    """Graph classifier built from stacked GATConv layers and a mean readout.

    Every node starts with its in-degree as its only feature. The first
    ``num_layers`` attention layers update the node features, the features
    of all nodes in a graph are averaged, and a linear layer maps that
    average to class scores.

    Args:
        num_layers: Number of GATConv layers applied in ``forward``.
        in_dim: Input feature size of the first layer. ``forward`` feeds a
            single in-degree value per node, so this is presumably 1 --
            confirm against the caller.
        hidden_dim: Per-head output size of the input and hidden layers.
        num_heads: Number of attention heads used by every layer.
        n_classes: Number of classes scored by the final linear layer.
        activation: Activation handed to the input and hidden layers.
        feat_drop: Forwarded to every GATConv layer.
        attn_drop: Forwarded to every GATConv layer.
        negative_slope: Forwarded to every GATConv layer.
        residual: Forwarded to every GATConv layer except the first.
    """

    def __init__(self, num_layers, in_dim, hidden_dim, num_heads, n_classes,
                 activation, feat_drop, attn_drop, negative_slope, residual):
        super().__init__()
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # Input projection (no residual).
        self.gat_layers.append(GATConv(
            in_dim, hidden_dim, num_heads,
            feat_drop, attn_drop, negative_slope, False, self.activation))
        # Hidden layers: the heads of the previous layer are flattened in
        # forward(), so each layer's input size is hidden_dim * num_heads.
        for _ in range(1, num_layers):
            self.gat_layers.append(GATConv(
                hidden_dim * num_heads, hidden_dim, num_heads,
                feat_drop, attn_drop, negative_slope, residual,
                self.activation))
        # Output projection.
        # NOTE(review): forward() only applies gat_layers[0:num_layers], so
        # this extra layer is registered but never called. It is kept so the
        # module's parameter names stay unchanged; consider removing it.
        self.gat_layers.append(GATConv(
            hidden_dim * num_heads, n_classes, num_heads,
            feat_drop, attn_drop, negative_slope, residual, None))
        # Readout head: linear map from the averaged node features of a
        # graph to class scores.
        self.classify = nn.Linear(hidden_dim * num_heads, n_classes)

    def forward(self, g):
        """Return class scores for each graph in the (batched) graph ``g``.

        Note that the final node features are written to ``g.ndata['h']``,
        i.e. the graph passed in is modified.
        """
        # Initial node feature: the in-degree as one float column. For
        # undirected graphs the in-degree equals the out-degree.
        h = g.in_degrees().view(-1, 1).float()
        for i in range(self.num_layers):
            # Merge the per-head outputs into one feature vector per node.
            h = self.gat_layers[i](g, h).flatten(1)
        # Store the updated features on all nodes for the readout.
        g.ndata['h'] = h
        # Readout: average the node features of each graph.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
def collate(samples):
    """Combine ``(graph, label)`` samples into one mini-batch.

    Returns a tuple of the graph produced by ``dgl.batch`` over all sample
    graphs and a ``torch.long`` tensor holding the sample labels.
    """
    # Transpose the list of pairs into a list of graphs and a list of labels.
    graphs, labels = (list(column) for column in zip(*samples))
    merged_graph = dgl.batch(graphs)
    label_tensor = torch.tensor(labels, dtype=torch.long)
    return merged_graph, label_tensor
and for training the model:
# Train for a fixed number of epochs and record the mean batch loss of each.
num_epochs = 16
for epoch in range(num_epochs):
    # Per-epoch accumulators; test_loss is reset here for the evaluation pass.
    epoch_loss = 0
    test_loss = 0
    num_batches = 0
    for bg, label in data_loader:
        prediction = model(bg)
        loss = loss_function(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        num_batches += 1
    # Average over the batches seen; an empty loader leaves the loss at 0
    # instead of failing on an undefined batch index.
    if num_batches:
        epoch_loss /= num_batches
    logging.info('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
and for evaluating the model:
# Evaluate on the test loader and record the mean batch loss.
# test_loss is expected to have been reset to 0 before this pass.
model.eval()
num_test_batches = 0
# No gradients are needed for evaluation, so skip building the autograd graph.
with torch.no_grad():
    for test_bg, test_label in test_data_loader:
        prediction = model(test_bg)
        loss = loss_function(prediction, test_label)
        test_loss += loss.item()
        num_test_batches += 1
# Average over the batches seen; with an empty loader test_loss is left as is
# rather than divided by a stale or undefined batch index.
if num_test_batches:
    test_loss /= num_test_batches
test_losses.append(test_loss)
# Switch back to training mode for the next epoch.
model.train()
I would be happy for any input.