NotImplementedError when training a GraphSage

I am trying to modify code in the existing repository for link prediction task. In this repository they have used RGCN for link prediction task link. Now I have tried to replace the RGCN with GraphSage but I am getting this error.

INFO:root:Input dim : 8, # Relations : 86, # Augmented relations : 109
utils/../experiments/ddi_hop3
INFO:root:No existing model found. Initializing new model..
INFO:numexpr.utils:NumExpr defaulting to 2 threads.
INFO:root:Total number of parameters: 131926
INFO:root:Starting training with full batch...
0it [00:00, ?it/s]Traceback (most recent call last):
  File "train.py", line 222, in <module>
    main(params)
  File "train.py", line 97, in main
    trainer.train()
  File "/content/managers/trainer.py", line 132, in train
    loss, auc, auc_pr, weight_norm = self.train_epoch()
  File "/content/managers/trainer.py", line 67, in train_epoch
    score_pos = self.graph_classifier(data_pos)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/model/dgl/graph_classifier.py", line 147, in forward
    g.ndata['h'] = self.gnn(g)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 201, in _forward_unimplemented
    raise NotImplementedError
NotImplementedError
0it [00:05, ?it/s]

This is the code I have tried to replace RGCN with GraphSage

class GraphClassifier(nn.Module):
    def __init__(self,params,relation2id): #, in_feats, h_feats
        super().__init__()
        self.params = params
        self.relation2id = relation2id
        self.train_rels = params.train_rels
        self.relations = params.num_rels
        self.gnn = nn.ModuleList()
        self.gnn.append(SAGEConv(self.params.feat_dim, self.params.emb_dim, aggregator_type='mean'))
        for i in range(self.params.num_gcn_layers - 1):
            self.gnn.append(SAGEConv(self.params.emb_dim, self.params.emb_dim, aggregator_type='mean'))
        self.dropout = nn.Dropout(p=0.3)
        self.relu = nn.ReLU()

        self.mp_layer1 = nn.Linear(self.params.feat_dim, self.params.emb_dim)
        self.mp_layer2 = nn.Linear(self.params.emb_dim, self.params.emb_dim)

        # self.fc_layer = nn.Linear(self.params.emb_dim * (1+self.params.num_gcn_layers), self.train_rels)
        # self.drugfeat = nn.Embedding(self.params.num_drugs, self.params.emb_dim)
        # self.drugfeat.weight.data = torch.from_numpy(np.random.normal(0, 0.1, (self.params.num_drugs, self.params.emb_dim)))
        # self.drugfeat.weight.requires_grad = True
        # self.drugfeat = self.drugfeat.cuda()
        if self.params.add_ht_emb and self.params.add_sb_emb:
            if self.params.add_feat_emb and self.params.add_transe_emb:
                self.fc_layer = nn.Linear(3 * (1+self.params.num_gcn_layers) * self.params.emb_dim + 2*self.params.emb_dim, self.train_rels)
            elif self.params.add_feat_emb :
                self.fc_layer = nn.Linear(3 * (self.params.num_gcn_layers) * self.params.emb_dim + 2*self.params.emb_dim, self.train_rels)
            else:
                self.fc_layer = nn.Linear(3 * (1+self.params.num_gcn_layers) * self.params.emb_dim, self.train_rels)
        elif self.params.add_ht_emb:
            self.fc_layer = nn.Linear(2 * (1+self.params.num_gcn_layers) * self.params.emb_dim, self.train_rels)
        else:
            self.fc_layer = nn.Linear(self.params.num_gcn_layers * self.params.emb_dim, self.train_rels)

    def drug_feat(self, emb):
        self.drugfeat = emb

    def forward(self, data): #, g, drug_idx, head_ids, tail_ids
        g = data
        g.ndata['h'] = self.gnn(g)
        g_out = mean_nodes(g, 'repr')


        g.ndata['repr'] = self.drugfeat(g.ndata['idx'])
        for i in range(self.params.num_gcn_layers):
            g.ndata['repr'] = self.gnn[i](g, g.ndata['repr'])
            g.ndata['repr'] = self.relu(g.ndata['repr'])
            g.ndata['repr'] = self.dropout(g.ndata['repr'])
        head_ids = (g.ndata['id'] == 1).nonzero().squeeze(1)
        head_embs = g.ndata['repr'][head_ids]

        tail_ids = (g.ndata['id'] == 2).nonzero().squeeze(1)
        tail_embs = g.ndata['repr'][tail_ids]
        head_feat = self.drugfeat[g.ndata['idx'][head_ids]]
        tail_feat = self.drugfeat[g.ndata['idx'][tail_ids]]
        # drug_feat = self.drugfeat[drug_idx]
        # print(drug_feat, drug_feat.shape)
        if self.params.add_feat_emb:
                fuse_feat1 = self.mp_layer2( self.relu( self.dropout( self.mp_layer1(
                                head_feat #torch.cat([head_feat, tail_feat], dim = 1)
                            ))))
                fuse_feat2 = self.mp_layer2( self.relu( self.dropout( self.mp_layer1(
                                tail_feat #torch.cat([head_feat, tail_feat], dim = 1)
                            ))))
                fuse_feat = torch.cat([fuse_feat1, fuse_feat2], dim = 1)
        if self.params.add_ht_emb and self.params.add_sb_emb:
            if self.params.add_feat_emb and self.params.add_transe_emb:
                g_rep = torch.cat([g_out.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
                                   head_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
                                   tail_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
                                   fuse_feat.view(-1, 2*self.params.emb_dim)
                                   ], dim=1)
            elif self.params.add_feat_emb:
                g_rep = torch.cat([g_out.view(-1, (self.params.num_gcn_layers) * self.params.emb_dim),
                                   head_embs.view(-1, (self.params.num_gcn_layers) * self.params.emb_dim),
                                   tail_embs.view(-1, (self.params.num_gcn_layers) * self.params.emb_dim),
                                   fuse_feat.view(-1, 2*self.params.emb_dim)
                                   ], dim=1)
            else:
                g_rep = torch.cat([g_out.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
                                   head_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
                                   tail_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
                                   #fuse_feat.view(-1, 2*self.params.emb_dim)
                                   ], dim=1)
            
        elif self.params.add_ht_emb:
            g_rep = torch.cat([
                                head_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
                                tail_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim)
                               ], dim=1)
        else:
            g_rep = g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim)
        #print(g_rep.shape, self.params.add_ht_emb, self.params.add_sb_emb)
        output = self.fc_layer(F.dropout(g_rep, p =0.3))
        # print(head_ids.detach().cpu().numpy(), tail_ids.detach().cpu().numpy())
        return output

Please let me know if their is any implementation error in graph sage that is causing this error?.

From your log, the error stack is also from PyTorch so the problem is probably some misuse of PyTorch NN module.

I suspect the problem is the use of self.gnn which is initialized as an ModuleList.

self.gnn = nn.ModuleList()
self.gnn.append(SAGEConv(self.params.feat_dim, self.params.emb_dim, aggregator_type='mean'))
for i in range(self.params.num_gcn_layers - 1):
    self.gnn.append(SAGEConv(self.params.emb_dim, self.params.emb_dim, aggregator_type='mean'))

But is used as an NN module

g.ndata['h'] = self.gnn(g)

Besides, the SAGEConv layer has two arguments, g and node_feature so the invocation also doesn’t look right.