Hello guys, sorry if this is a rookie question.
So far I have tried GCN for graph classification. This is the code I'm using, which is mostly based on the tutorial on the DGL website:
    import dgl
    import torch
    import torch.nn as nn
    from dgl.nn.pytorch import GraphConv

    class Classifier(nn.Module):
        def __init__(self, in_dim, hidden_dim1, hidden_dim2, hidden_dim3, n_classes):
            super(Classifier, self).__init__()
            # Three GCN layers followed by a linear classifier on the graph embedding
            self.conv1 = GraphConv(in_dim, hidden_dim1)
            self.conv2 = GraphConv(hidden_dim1, hidden_dim2)
            self.conv3 = GraphConv(hidden_dim2, hidden_dim3)
            self.classify = nn.Linear(hidden_dim3, n_classes)

        def forward(self, g, device):
            h = g.ndata['x']  # input node features
            h = torch.relu(self.conv1(g, h))
            h = torch.relu(self.conv2(g, h))
            h = torch.relu(self.conv3(g, h))
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')  # average node states into one vector per graph
            return self.classify(hg)
where g.ndata['x'] holds the node features. How can I change this code to use GIN instead of GCN? I was reading the GINConv code, and it does not seem to take input and output dimensions the way GraphConv does (again, sorry if it's a rookie question), so I'm not sure how to swap it into my code.
I couldn't find a tutorial anywhere that shows how to implement GIN using DGL.