GIN on a custom graph dataset for graph classification is very unstable, with high loss and low accuracy compared to GCN?

If I am not mistaken, based on the paper, GIN with a 1-layer MLP and both poolings set to mean should be very similar to GCN.
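Concretely, the correspondence I have in mind from the paper is roughly the following (my own transcription, so the exact form may be slightly off). GIN updates a node with a sum over its neighbors followed by an MLP,

h_v^{(k)} = \mathrm{MLP}^{(k)}\left((1+\epsilon^{(k)})\, h_v^{(k-1)} + \sum_{u\in\mathcal{N}(v)} h_u^{(k-1)}\right)

while the GCN-like ablation uses a mean over the node and its neighbors followed by a single linear layer and ReLU:

h_v^{(k)} = \mathrm{ReLU}\left(W^{(k)}\cdot\mathrm{mean}\left\{h_u^{(k-1)} : u\in\mathcal{N}(v)\cup\{v\}\right\}\right)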

I am using the GIN code provided by DGL on my own dataset: https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin

But the problem is that when I compare GIN with mean/mean pooling and a 1-layer MLP against my own GCN implemented with DGL's GraphConv, GIN is much worse!

For example, this is GIN after 25 epochs:

train set - average loss: 2.4517, accuracy: 6.7742%
valid set - average loss: 2.4659, accuracy: 8.0460%
epoch-25: 

And here it is with sum neighbor pooling and sum graph pooling after 25 epochs:

valid set - average loss: 291.1245, accuracy: 16%
train set - average loss: 101.9377, accuracy: 16%

And this is my GCN:

train set - average loss: 1.5589, accuracy: 50.8065%
valid set - average loss: 1.5502, accuracy: 43.6782%
epoch-25:

Everything else is the same: the dataset, learning rate, and optimization parameters. This is basically my GCN:

import dgl
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim1, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim1)
        self.classify = nn.Linear(hidden_dim1, n_classes)

    def forward(self, g, h):
        h = torch.relu(self.conv1(g, h))
        g.ndata['h'] = h
        hg = dgl.mean_nodes(g, 'h')  # mean graph pooling
        return self.classify(hg)

My GIN code basically uses model = GIN(…), the same as the provided DGL example. The way I calculate loss and accuracy is exactly the same in both, the loss function is CrossEntropyLoss in both too, and there is no dropout; my code is very similar to main.py of the GIN example.

Is this normal? GIN just seems very unstable: the accuracy swings up and down a lot, and when I switch to sum/sum the loss goes through the roof. Changing mean/mean to anything else doesn't help much either…

This is the evaluation code for both GCN and GIN, identical to the provided DGL examples:

def eval_net(args, net, dataloader, criterion):
    net.eval()

    total = 0
    total_loss = 0
    total_correct = 0

    for data in dataloader:
        graphs, labels = data

        # node in-degrees are used as the input feature
        feat = graphs.in_degrees().view(-1, 1).float().to(args.device)

        graphs = graphs.to(args.device)
        labels = labels.to(args.device)
        total += len(labels)
        outputs = net(graphs, feat)

        _, predicted = torch.max(outputs.data, 1)

        total_correct += (predicted == labels.data).sum().item()
        loss = criterion(outputs, labels)
        total_loss += loss.item() * len(labels)

    loss, acc = 1.0 * total_loss / total, 1.0 * total_correct / total

    net.train()

    return loss, acc

And my train function, which again is identical for both of them:

def train(args, net, trainloader, optimizer, criterion, epoch):
    net.train()

    running_loss = 0
    total_iters = len(trainloader)
    bar = tqdm(range(total_iters), unit='batch', file=sys.stdout)

    for pos, (graphs, labels) in zip(bar, trainloader):

        labels = labels.to(args.device)

        feat = graphs.in_degrees().view(-1, 1).float().to(args.device)

        graphs = graphs.to(args.device)
        outputs = net(graphs, feat)

        loss = criterion(outputs, labels)
        running_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # report
        bar.set_description('epoch-{}'.format(epoch))
    bar.close()
    # the final batch will be aligned
    running_loss = running_loss / total_iters

    return running_loss

Training for more epochs doesn't help either: GIN with 100 epochs is still worse than GCN with 20 epochs! Why?!

  1. By "1-layer MLP", do you mean a single linear layer, or two linear layers with an activation function in the middle?
  2. You may need to tune the learning rate independently for GCN and GIN.
  3. You are using 1 layer for GCN. How many layers are you using for GIN?
  1. num_mlp_layers = 1 and num_layers = 1, so both are 1.

  2. I tried many, many learning rates and nothing worked; it is still much worse than GCN. I also tried setting the dropout to 0, but no luck, still worse than GCN.

  3. Only one layer.

Also, if needed, I am willing to share part of the dataset so you can check what is happening, because this is not making any sense. I even tried a different dataset and removed the isolated nodes to make sure they were not the cause, still no luck. A simple way of testing it is to use the degree as the only feature and check how it performs.

My gut feeling is that a one-layer GIN or GCN is probably not very expressive, since a node cannot gather information from context that is topologically farther away. The 1-WL test, which is the upper bound of GIN's expressive power, also requires multiple message-passing steps to work.

That being said, you can try one of the datasets we have tested and see whether the same behavior exists, as sketched below. If so, it is probably a problem in your model configuration. Otherwise it is probably a problem with your dataset (preprocessing or otherwise).
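For instance, here is a minimal sketch of loading one of the graph-classification datasets from the GIN paper that ships with DGL (the exact GINDataset arguments may differ slightly across DGL versions):

from dgl.data import GINDataset

# MUTAG is one of the graph-classification datasets used in the GIN paper;
# self_loop=True adds self-loops the way the example expects
dataset = GINDataset('MUTAG', self_loop=True)
graph, label = dataset[0]
print(len(dataset), graph.number_of_nodes(), label)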

But as I said, using DGL's GCN module (GraphConv) directly instead of GIN gives much better results than GIN with num_mlp_layers = 1 and num_layers = 1 and the same hyperparameters on my dataset, with no problems. So this is not a problem with my dataset: if there were a problem, then the GCN I implemented using GraphConv would also give bad results, because in theory both of them should be the same. And my model is literally the same GIN provided by DGL, and my GCN is shown in my question, so I don't think it is a problem with my model or dataset.

Here is another example I tried on a different dataset, using GIN with sum/sum after 60-70 epochs:

train set - average loss: 35.4424, accuracy: 65.6886%
valid set - average loss: 36.2543, accuracy: 65.4666%

train set - average loss: 0.1639, accuracy: 96.7440%
valid set - average loss: 0.2084, accuracy: 96.1398%

train set - average loss: 17.7134, accuracy: 67.1898%
valid set - average loss: 18.1681, accuracy: 66.8313%

And this is with the learning rate set to 0.0001!! My GCN is above 98% at the same epochs, with no problem and without the accuracy jumping around so much. Again, my code is basically the same for both of them; the only difference is that for GIN the model is model = GIN(…), and for GCN I construct a 1-layer GCN as explained in the tutorials.

With mean/mean and 1 layer, the loss does not swing as much, but the accuracy is still much worse than GCN, for example 90% vs. 99%.

Again, I am using the GIN.py provided in the DGL repository and just passing parameters to it, so I am not sure whether I should change it or not.

By saying “problem of your dataset” I don't mean that your dataset is wrong; maybe it's just that GCN simply works better than GIN on your particular dataset. After all, GCN has a different formulation from GIN: GCN adds self-loops and aggregates neighboring features with a weighted sum determined by the Laplacian of the graph, instead of the simple sum + concatenation that GIN uses.

What is this dataset? In particular, I'm curious whether it is your own, a dataset mentioned in GIN's paper, or a dataset tested in our example.

After all, GCN has a different formulation from GIN: GCN adds self-loops and aggregates neighboring features with a weighted sum determined by the Laplacian of the graph, instead of the simple sum + concatenation that GIN uses.

But the problem is that, based on the GIN paper, GIN with mean/mean is very similar to GCN if not identical; that is how they compare it with GCN, and you can see it in the charts in the paper. In theory they say sum/sum should be better than mean/mean, not 10 times worse! In my case it is not only worse than mean/mean, it performs very poorly, as I showed in the posts above.

What is this dataset? In particular, I'm curious whether it is your own, a dataset mentioned in GIN's paper, or a dataset tested in our example.

It's a custom dataset with added self-loops and dual edges (if there is an edge from a to b, then there is an edge from b to a). Every node has a 1500-dimensional input feature, and the number of nodes per graph varies from 5-6 nodes to 1000-2000 nodes. I should also mention that the features are sparse, meaning that many of the values in these 1500-dimensional vectors are 0, and all of them are either 0 or above 0 (I also tried normalizing the input features, but it didn't help).
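In case it helps, this is roughly how each graph is built (a toy sketch with made-up tensors; I am assuming a recent DGL where dgl.graph, dgl.to_bidirected, and dgl.add_self_loop exist):

import dgl
import torch

# toy 3-node cycle; my real graphs have 5 to ~2000 nodes
src = torch.tensor([0, 1, 2])
dst = torch.tensor([1, 2, 0])
g = dgl.graph((src, dst))
g = dgl.to_bidirected(g)  # dual edges: b->a for every a->b
g = dgl.add_self_loop(g)  # a self-loop on every node

# sparse, non-negative 1500-dim features (mostly zeros here)
feat = torch.rand(g.number_of_nodes(), 1500) * (torch.rand(g.number_of_nodes(), 1500) > 0.9)
# row normalization was one of the things I tried (it didn't help);
# clamp avoids dividing by zero for all-zero rows
g.ndata['feat'] = feat / feat.sum(dim=1, keepdim=True).clamp(min=1)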

Let's directly compare the computation in the Classifier you defined and in GIN to understand the differences.

Let G=(V, E) be a graph where V is the set of nodes and E is the set of edges. Let h_{v}^{(l)}\in\mathbb{R}^{1\times d_l} be the representation of node v\in V after the l-th GNN (GCN/GIN) layer. Let h_v^{(0)} be the input feature for node v.

Classifier

The classifier proceeds as follows:

Updates the representation of each node v\in V

h_{v}^{(1)} = \sigma\left(b^{(1)}+\sum_{u\in\mathcal{N}(v)}\frac{1}{\sqrt{\tilde{d}_{out, u} \tilde{d}_{in, v}}}h_u^{(0)}W^{(1)}\right)

where:

  • \sigma is the ReLU activation function
  • b^{(1)}\in\mathbb{R}^{1\times d_1} is a learnable bias
  • \mathcal{N}(v) = \{u\mid (u, v)\in E\} is the set of predecessors of node v
  • \tilde{d}_{out, u}=\max(d_{out,u}, 1), where d_{out,u} is the out-degree of u, i.e. the number of successors of u.
  • \tilde{d}_{in, v}=\max(d_{in,v}, 1), where d_{in,v} is the in-degree of v, i.e. the number of predecessors of v.
  • W^{(1)}\in\mathbb{R}^{d_0\times d_1} is a learnable weight

Computes the representation of the graph G

h_G = \frac{1}{|V|}\sum_{v\in V}h_v^{(1)}

Performs the prediction for the graph

h_{G}W^{pred}+b^{pred}

where:

  • W^{pred}\in\mathbb{R}^{d_1\times C} is a learnable weight, C is the output size (e.g. the number of classes)
  • b^{pred}\in\mathbb{R}^{1\times C} is a learnable bias

GIN

I assume you initialized the model with GIN(num_layers=1, num_mlp_layers=1, graph_pooling_type='mean', neighbor_pooling_type='mean').

In that case, you are actually skipping all GNN layers according to the code (see the paraphrased excerpt below) and only performing the following steps.
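Roughly, the constructor in the example's gin.py builds its message-passing layers like this (paraphrased, not a verbatim copy, so double-check the file):

# only num_layers - 1 GINConv layers are created, so num_layers=1
# leaves self.ginlayers empty and no message passing happens at all
for layer in range(self.num_layers - 1):
    if layer == 0:
        mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)
    else:
        mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)
    self.ginlayers.append(
        GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps))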

Computes the representation of the graph G

h_G = \frac{1}{|V|}\sum_{v\in V}h_v^{(0)}

Performs the prediction for the graph

h_{G}W^{pred}+b^{pred}

where:

  • W^{pred}\in\mathbb{R}^{d_0\times C} is a learnable weight acting directly on the input features, C is the output size (e.g. the number of classes)
  • b^{pred}\in\mathbb{R}^{1\times C} is a learnable bias
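If you want one actual round of message passing, comparable to the single GraphConv in your Classifier, then under this convention you would pass num_layers=2. A minimal sketch, assuming the constructor signature from the example's gin.py (in_dim, hidden_dim, and n_classes are placeholders for your sizes):

# hypothetical configuration: num_layers=2 creates exactly one GIN layer
model = GIN(num_layers=2, num_mlp_layers=1,
            input_dim=in_dim, hidden_dim=hidden_dim, output_dim=n_classes,
            final_dropout=0.0, learn_eps=False,
            graph_pooling_type='mean', neighbor_pooling_type='mean')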