I am using DGL to train a GAT model to classify nodes in graph-structured data, and everything I have tried has resulted in terrible classification accuracy.
After many days of trying to train a GAT model that classifies nodes into 1 of 32 classes, I have forced the node features to be linearly separable: nodes in class 0 have a feature vector with a 1 at position 0 followed by 31 zeros, class 1 has a feature vector of [0 1 0 ...], and so on.
The dataset is completely class balanced. I have forced there to be 64,000 nodes in a single graph, with 2,000 nodes of each class. The structure of the graph is determined by my dataset, but on average each node has 3 edges.
Even with linearly separable node features for each class, I cannot achieve more than 5% classification accuracy with my GAT implementation in DGL.
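To confirm the features really are trivially separable, here is a minimal standalone sketch (no graph involved) in which a single linear layer should reach 100% accuracy on one-hot features of this form:

# Standalone check: one-hot features are trivially linearly separable,
# so a plain linear layer should classify them perfectly.
import torch
import torch.nn as nn

n_classes = 32
X = torch.eye(n_classes)         # one feature vector per class
y = torch.arange(n_classes)      # class c has a 1 at position c
clf = nn.Linear(n_classes, n_classes)
opt = torch.optim.Adam(clf.parameters(), lr=0.1)
for _ in range(200):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(clf(X), y)
    loss.backward()
    opt.step()
print((clf(X).argmax(1) == y).float().mean())  # expect 1.0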
Please could someone tell me what is wrong with my code:
import dgl
import networkx as nx
import numpy as np
import torch
import tqdm

n_nodes = 64000    # values from the description above
n_classes = 32
in_dim = 32

# G is the networkx graph built from my dataset
relabelled_g = nx.convert_node_labels_to_integers(G)
dgl_g = dgl.from_networkx(relabelled_g)
# force an even distribution of classes; set the feature vector of each node
# to zeros with a 1 at position {class}
n_samples_per_class = 1 + (n_nodes // n_classes)
labels = np.ndarray((n_nodes,), dtype=np.int64)
features = np.ndarray((n_nodes, in_dim), dtype=np.float64)
for i in tqdm.tqdm(range(n_nodes)):
    x = i // n_samples_per_class
    labels[i] = x
    d = np.zeros((in_dim,), dtype=np.float64)
    d[x] = 1.0
    features[i, :] = d
dgl_g.ndata['label'] = torch.from_numpy(labels)
dgl_g.ndata['feat'] = torch.from_numpy(features).float()
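A quick sanity check on the constructed tensors (sizes taken from the description above):

# verify class balance, one-hot features, and average degree
lab = dgl_g.ndata['label']
feat = dgl_g.ndata['feat']
assert lab.shape[0] == 64000
print(torch.bincount(lab))                    # roughly 2,000 nodes per class
assert (feat.sum(dim=1) == 1).all()           # every feature vector is one-hot
print(dgl_g.num_edges() / dgl_g.num_nodes())  # average edges per node (~3 expected)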
"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
class GAT(nn.Module):
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, negative_slope, False, self.activation))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GATConv(
                num_hidden * heads[l - 1], num_hidden, heads[l],
                feat_drop, attn_drop, negative_slope, residual, self.activation))
        # output projection
        self.gat_layers.append(GATConv(
            num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, negative_slope, residual, None))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            # concatenate the heads of every non-output layer
            h = self.gat_layers[l](self.g, h).flatten(1)
        # output projection: average over the output heads
        logits = self.gat_layers[-1](self.g, h).mean(1)
        return logits
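As a shape sanity check (a toy 4-node cycle graph, not my real data), the model should map N x in_dim inputs to N x num_classes logits:

# toy graph: 0 -> 1 -> 2 -> 3 -> 0, so every node has in-degree 1
toy_g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
toy_net = GAT(toy_g, num_layers=1, in_dim=8, num_hidden=16, num_classes=3,
              heads=[2, 1], activation=F.elu, feat_drop=0.0, attn_drop=0.0,
              negative_slope=0.2, residual=False)
print(toy_net(torch.randn(4, 8)).shape)  # torch.Size([4, 3])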
def accuracy(logits, labels):
    correct = torch.sum(logits == labels)
    return correct.item() * 1.0 / len(labels)
negative_slope = 0.5
attention_drop = 0.5
in_drop = 0.5
num_hidden = 4
num_heads = 3
num_layers = 2
num_out_heads = 2
epochs = 100
residual = False
features = dgl_g.ndata['feat']
labels = dgl_g.ndata['label'].reshape(-1, 1).float()
heads = ([num_heads] * num_layers) + [num_out_heads]
net = GAT(dgl_g,
          num_layers,
          features.size()[1],
          num_hidden,
          1,
          heads,
          F.relu,
          in_drop,
          attention_drop,
          negative_slope,
          residual)
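For reference, with these settings the model stacks three GATConv layers; the widths work out as:

# heads == [3, 3, 2], in_dim = 32 (from the one-hot features), num_hidden = 4
# layer 0: GATConv(32, 4, 3)  -> heads flattened to 4 * 3 = 12 per node
# layer 1: GATConv(12, 4, 3)  -> heads flattened to 12 per node
# output : GATConv(12, 1, 2)  -> 2 heads averaged to a single logit per node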
#loss_fcn = torch.nn.CrossEntropyLoss()
loss_fcn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=4e-2)
# main loop
dur = []
for epoch in range(epochs):
    if epoch >= 3:
        t0 = time.time()
    optimizer.zero_grad()
    logits = net(features)
    loss = loss_fcn(logits, labels)
    loss.backward()
    optimizer.step()
    if epoch >= 3:
        dur.append(time.time() - t0)
    train_acc = accuracy(logits, labels)
    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | Train Acc {:.4f}".format(
        epoch, loss.item(), np.mean(dur), train_acc))
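The accuracy above is computed from the same forward pass as the loss, i.e. with dropout active; a separate check with dropout disabled would look like this:

net.eval()
with torch.no_grad():
    eval_logits = net(features)
print(accuracy(eval_logits, labels))
net.train()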
During training, the accuracy initially increases from 0.01 to about 0.05 over the first 100 epochs, but by 500 epochs it decreases back to 0.01.
Thanks.