Terrible node classification accuracy in multiclass classification

I am using DGL to train a GAT model to classify nodes in graph-structured data. Everything I have tried has resulted in very poor classification accuracy.

After many days of trying to train a GAT model that classifies nodes into 1 of 32 classes, I have forced the node features to be linearly separable, i.e. each node's feature vector is the one-hot encoding of its class: nodes in class 0 have a 1 at position 0 followed by 31 zeros, nodes in class 1 have [0, 1, 0, ..., 0], and so on.
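As a reference point, the one-hot construction can be written without a per-node loop. Below is a minimal NumPy sketch. It uses a small, hypothetical `per_class` value so it runs instantly, and it also checks that a plain linear map (the identity) followed by argmax already labels every node correctly:

```python
import numpy as np

n_classes = 32
per_class = 4  # small illustrative size; the post uses 2000 per class

# Balanced labels: 0,0,0,0,1,1,1,1,...,31
labels = np.repeat(np.arange(n_classes), per_class)

# Row k of the identity matrix is the one-hot vector for class k,
# so fancy-indexing with the label array gives one feature row per node.
features = np.eye(n_classes, dtype=np.float32)[labels]

# Sanity check: the identity map followed by argmax (a linear classifier
# with no bias) recovers every label, so the features are linearly separable.
pred = features.argmax(axis=1)
print(features.shape, (pred == labels).mean())  # (128, 32) 1.0
```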

The dataset is completely class balanced: I have forced a single graph of 64,000 nodes, with 2,000 nodes of each class. The graph structure is determined by my dataset, and on average each node has 3 edges.
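One detail is worth checking against this claim. The code below sets `n_samples_per_class = 1 + (n_nodes // n_classes)`. For 64,000 nodes and 32 classes, that gives 2,001 nodes to each of classes 0 to 30 and only 1,969 to class 31. The NumPy sketch compares that scheme with `i * n_classes // n_nodes`. This is a hypothetical alternative that yields exactly 2,000 per class when `n_nodes` is divisible by `n_classes`. The imbalance is small, so on its own it probably does not explain the low accuracy.

```python
import numpy as np

n_nodes, n_classes = 64_000, 32
idx = np.arange(n_nodes)

# Scheme used in the post: block size rounded up by one (2001 here).
block = 1 + n_nodes // n_classes
labels_post = idx // block

# Alternative: scale each index into [0, n_classes) so every class gets
# exactly n_nodes / n_classes nodes when the division is exact.
labels_even = idx * n_classes // n_nodes

counts_post = np.bincount(labels_post, minlength=n_classes)
counts_even = np.bincount(labels_even, minlength=n_classes)
print(counts_post[[0, -1]])  # [2001 1969]
print(counts_even[[0, -1]])  # [2000 2000]
```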

Even with linearly separable node features for each class, I cannot achieve more than 5% classification accuracy with my GAT implementation in DGL.
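For calibration: with 32 equally sized classes, a classifier that ignores its input (for example, one that always predicts the same class) scores

$$
\text{chance accuracy} = \frac{2000}{64\,000} = \frac{1}{32} \approx 3.1\%,
$$

so 5% is only slightly above chance.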

Please could someone tell me what is wrong with my code:

import dgl
import networkx as nx
import numpy as np
import torch
import tqdm

# G, n_nodes, n_classes and in_dim are assumed to be defined earlier (not shown)
relabelled_g = nx.convert_node_labels_to_integers(G)
dgl_g    = dgl.from_networkx(relabelled_g)
## force an even distribution of classes: set the feature vector of each node to zeros with a 1 at its class index
n_samples_per_class = 1 + (n_nodes // n_classes)
labels   = np.empty((n_nodes,), dtype=np.int64)       # np.int was removed in NumPy 1.24
features = np.empty((n_nodes, in_dim), dtype=np.float64)
for i in tqdm.tqdm(range(n_nodes)):
    x = i // n_samples_per_class
    labels[i] = x
    d = np.zeros((in_dim,), dtype=np.float64)          # np.float was removed in NumPy 1.24
    d[x] = 1.0
    features[i, :] = d

dgl_g.ndata['label'] = torch.from_numpy(labels)
dgl_g.ndata['feat']  = torch.from_numpy(features).float()
"""
Graph Attention Networks in DGL using SPMV optimization.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
Pytorch implementation: https://github.com/Diego999/pyGAT
"""


import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv


class GAT(nn.Module):
    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, negative_slope, False, self.activation))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(GATConv(
                num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, negative_slope, residual, self.activation))
        # output projection
        self.gat_layers.append(GATConv(
            num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, negative_slope, residual, None))

    def forward(self, inputs):
        h = inputs
        for l in range(self.num_layers):
            h = self.gat_layers[l](self.g, h).flatten(1)
        # output projection
        logits = self.gat_layers[-1](self.g, h).mean(1)
        return logits


def accuracy(logits, labels):
    correct = torch.sum(logits == labels)
    return correct.item() * 1.0 / len(labels)


negative_slope = 0.5
attention_drop = 0.5
in_drop        = 0.5
num_hidden     = 4
num_heads      = 3
num_layers     = 2
num_out_heads  = 2
epochs         = 100
residual       = False

features  = dgl_g.ndata['feat']
labels    = dgl_g.ndata['label'].reshape(-1, 1).float()

heads = ([num_heads] * num_layers) + [num_out_heads]
net = GAT(dgl_g,
          num_layers,
          features.size()[1],
          num_hidden,
          1,
          heads,
          F.relu,
          in_drop,
          attention_drop,
          negative_slope,
          residual)

#loss_fcn = torch.nn.CrossEntropyLoss()
loss_fcn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=4e-2)

# main loop
dur = []
for epoch in range(epochs):
    if epoch >= 3:
        t0 = time.time()

    optimizer.zero_grad()

    logits = net(features)
    loss = loss_fcn(logits, labels)

    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    train_acc = accuracy(logits, labels)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | Train Acc {:.4f}".format(
        epoch, loss.item(), np.mean(dur), train_acc))
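Two parts of the code above look like likely culprits. First, the model is built with `num_classes` set to 1 and trained with `BCEWithLogitsLoss` against float labels from 0 to 31, whereas binary cross-entropy expects targets between 0 and 1. Second, `accuracy` compares raw logits to the labels with `==` instead of taking an argmax. The sketch below shows the usual multiclass setup: one output per class, `CrossEntropyLoss` with integer labels of shape (N,), and argmax accuracy. It uses `nn.Linear` as a hypothetical stand-in for the GAT so that it runs without DGL or the original graph, and it assumes a smaller weight decay (5e-4) than the post's 4e-2. That value is a guess, not a tuned one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_classes, per_class = 32, 50  # small illustrative sizes

# Integer class indices of shape (N,), as CrossEntropyLoss expects
# (no .reshape(-1, 1).float()).
labels = torch.arange(n_classes).repeat_interleave(per_class)
features = F.one_hot(labels, n_classes).float()

# Hypothetical stand-in for the GAT: any model that outputs one logit
# per class, i.e. shape (N, n_classes) rather than (N, 1).
model = nn.Linear(n_classes, n_classes)
loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def accuracy(logits, labels):
    # Predicted class = index of the largest logit; compare that to the
    # label instead of comparing raw logit values with the label.
    pred = logits.argmax(dim=1)
    return (pred == labels).float().mean().item()


model.train()
for epoch in range(200):
    optimizer.zero_grad()
    logits = model(features)
    loss = loss_fcn(logits, labels)
    loss.backward()
    optimizer.step()

# Evaluate without gradient tracking; for a GAT, eval() also disables dropout.
model.eval()
with torch.no_grad():
    acc = accuracy(model(features), labels)
print(f"final loss {loss.item():.4f} | train acc {acc:.4f}")
```

In the script above, the corresponding changes would be the following. Pass `n_classes` (32) instead of `1` as the fifth argument to `GAT(...)`. Switch to the commented-out `CrossEntropyLoss`. Keep `dgl_g.ndata['label']` as a 1-D integer tensor. Use an argmax-based `accuracy`. Whether these changes alone fix the GAT is untested here; the large `weight_decay` and the 0.5 dropout rates are other settings worth revisiting.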

During training, accuracy rises from 0.01 to 0.05 over the first 100 epochs, then falls back to 0.01 by epoch 500.

Thanks.

  1. Have you tried a vanilla GNN before? Say GCN?
  2. What graphs are you dealing with? What’s the task and scenario?
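For context on the first suggestion: a GCN layer (Kipf & Welling) computes $H' = \sigma(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} H W)$ with $\hat{A} = A + I$ and $\hat{D}$ the degree matrix of $\hat{A}$. Below is a minimal dense NumPy sketch of that propagation rule on a hypothetical three-node path graph. It illustrates the formula only and is not DGL's `GraphConv` implementation:

```python
import numpy as np


def gcn_propagate(adj, h, w):
    """One GCN layer: relu(D^-1/2 (A + I) D^-1/2 H W) on a dense adjacency."""
    a_hat = adj + np.eye(adj.shape[0])          # add self-loops
    deg = a_hat.sum(axis=1)
    d_inv_sqrt = np.diag(deg ** -0.5)
    norm_adj = d_inv_sqrt @ a_hat @ d_inv_sqrt  # symmetric normalisation
    return np.maximum(norm_adj @ h @ w, 0.0)


# Path graph 0-1-2 with one-hot node features and identity weights.
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
h = np.eye(3)
out = gcn_propagate(adj, h, np.eye(3))
print(out.round(3))
```

In the output, node 1 keeps a weight of 1/3 on its own feature, which is smaller than the 1/√6 ≈ 0.408 it gives each neighbour. This is a small illustration of how neighbourhood aggregation can dilute per-node one-hot features when the labels do not follow the graph structure.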