Hello guys, sorry if this is a rookie question.
So far I have used a GCN to classify graphs. This is the code I'm using, which is mostly based on the tutorial on the DGL website:
```python
import dgl
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim1, hidden_dim2, hidden_dim3, n_classes):
        super(Classifier, self).__init__()
        # Each GraphConv layer takes explicit input and output sizes
        self.conv1 = GraphConv(in_dim, hidden_dim1)
        self.conv2 = GraphConv(hidden_dim1, hidden_dim2)
        self.conv3 = GraphConv(hidden_dim2, hidden_dim3)
        self.classify = nn.Linear(hidden_dim3, n_classes)

    def forward(self, g, device):  # device is not used inside forward
        h = g.ndata['x']  # node features
        h = torch.relu(self.conv1(g, h))
        h = torch.relu(self.conv2(g, h))
        h = torch.relu(self.conv3(g, h))
        g.ndata['h'] = h
        # Readout: average the node embeddings into one vector per graph
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
```
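For reference, each `GraphConv` layer in this model computes roughly the following (default `norm='both'`, following the formula in the DGL `GraphConv` docs; the ReLU is applied outside the layer in the code). The weight matrix $W^{(l)} \in \mathbb{R}^{d_l \times d_{l+1}}$ is why `GraphConv` needs input and output sizes, and the last line is the mean readout followed by the linear classifier:

```latex
h_i^{(l+1)} = \mathrm{ReLU}\Big(b^{(l)} + \sum_{j \in \mathcal{N}(i)} \frac{1}{c_{ji}}\, h_j^{(l)} W^{(l)}\Big),
\qquad c_{ji} = \sqrt{|\mathcal{N}(j)|}\,\sqrt{|\mathcal{N}(i)|}

h_G = \frac{1}{|\mathcal{V}_G|} \sum_{i \in \mathcal{V}_G} h_i^{(3)}, \qquad \hat{y}_G = h_G W_c + b_c
```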
Here `g.ndata['x']` holds the node features. How can I change this code to use GIN instead of GCN? I was reading the `GINConv` code, and unlike `GraphConv` it doesn't seem to take input and output dimensions (again, sorry if this is a rookie question). How can I use GIN in my code instead?
I couldn't find any tutorial that shows how to implement GIN with DGL.