How to do graph classification using GIN?

Hello guys, sorry if this is a rookie question.

So far I have tried GCN to classify graphs. This is the code I'm using, which is mostly based on the tutorial on the DGL website:

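import dgl
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim1, hidden_dim2, hidden_dim3, n_classes):
        super(Classifier, self).__init__()
        # Each GraphConv takes explicit input and output feature sizes
        self.conv1 = GraphConv(in_dim, hidden_dim1)
        self.conv2 = GraphConv(hidden_dim1, hidden_dim2)
        self.conv3 = GraphConv(hidden_dim2, hidden_dim3)
        self.classify = nn.Linear(hidden_dim3, n_classes)

    def forward(self, g, device):
        # device is unused here; g and the model are assumed to be on the same device
        h = g.ndata['x']  # input node features

        h = torch.relu(self.conv1(g, h))
        h = torch.relu(self.conv2(g, h))
        h = torch.relu(self.conv3(g, h))
        g.ndata['h'] = h
        # Readout: average the node embeddings of each graph in the batch
        hg = dgl.mean_nodes(g, 'h')

        return self.classify(hg)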

Here g.ndata['x'] holds the node features. How can I change this code to use GIN instead of GCN? I was reading the GINConv code, and it doesn't seem to take input and output dimensions the way GraphConv does (again, sorry if it's a rookie question), so how can I use GIN in my code instead?

I couldn't find any tutorial that shows how to implement GIN using DGL.

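You can find the source code for GINConv here. Message passing in GINConv does not change the size of the node features: each node's new representation is (1 + eps) times its own features plus the aggregated neighbor features (sum, max or mean). Only after message passing and this residual update can you optionally apply a callable to the node features, such as a linear layer or an MLP, passed as the apply_func argument. That callable is where the feature dimension can change.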
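To make the point about sizes concrete, here is a minimal dependency-free sketch of one GIN layer with sum aggregation, following the update h_i' = apply_func((1 + eps) * h_i + sum of h_j over the in-neighbors j of i). The helper names gin_update and make_linear are hypothetical and only for illustration; this is not DGL's implementation, which performs the same computation on tensors via message passing.

```python
def gin_update(features, edges, apply_func, eps=0.0):
    """One GIN layer with sum aggregation (illustrative sketch).

    features: list of per-node feature vectors (lists of floats)
    edges: list of (src, dst) pairs; messages flow src -> dst
    apply_func: callable mapping one vector to another (may change its length)
    """
    dim = len(features[0])
    # Sum the features of incoming neighbors; result has the SAME length as the input
    agg = [[0.0] * dim for _ in features]
    for src, dst in edges:
        for k in range(dim):
            agg[dst][k] += features[src][k]
    # Residual-style update (1 + eps) * h_i + sum_j h_j, still the same dimension
    combined = [
        [(1.0 + eps) * h[k] + a[k] for k in range(dim)]
        for h, a in zip(features, agg)
    ]
    # Only apply_func can change the feature size
    return [apply_func(v) for v in combined]


def make_linear(weights, bias):
    """Hypothetical dense layer; weights is an out_dim x in_dim nested list."""
    def linear(v):
        return [sum(w * x for w, x in zip(row, v)) + b
                for row, b in zip(weights, bias)]
    return linear


# Usage: a 3-node directed cycle 0 -> 1 -> 2 -> 0 with 2-dim features
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 1), (1, 2), (2, 0)]
print(gin_update(features, edges, lambda v: v))  # identity: still 2-dim per node
lin = make_linear([[1, 0], [0, 1], [1, 1]], [0, 0, 0])
print(gin_update(features, edges, lin))          # linear layer: now 3-dim per node
```

With the identity as apply_func every node keeps 2 features; swapping in a 2-to-3 linear map is the only thing that changes the width, which is why GINConv itself has no input/output size arguments.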
