I am trying to modify code in the existing repository for link prediction task. In this repository they have used RGCN for link prediction task link. Now I have tried to replace the RGCN with GraphSage but I am getting this error.
INFO:root:Input dim : 8, # Relations : 86, # Augmented relations : 109
utils/../experiments/ddi_hop3
INFO:root:No existing model found. Initializing new model..
INFO:numexpr.utils:NumExpr defaulting to 2 threads.
INFO:root:Total number of parameters: 131926
INFO:root:Starting training with full batch...
0it [00:00, ?it/s]Traceback (most recent call last):
File "train.py", line 222, in <module>
main(params)
File "train.py", line 97, in main
trainer.train()
File "/content/managers/trainer.py", line 132, in train
loss, auc, auc_pr, weight_norm = self.train_epoch()
File "/content/managers/trainer.py", line 67, in train_epoch
score_pos = self.graph_classifier(data_pos)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/content/model/dgl/graph_classifier.py", line 147, in forward
g.ndata['h'] = self.gnn(g)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 201, in _forward_unimplemented
raise NotImplementedError
NotImplementedError
0it [00:05, ?it/s]
This is the code I have tried to replace RGCN with GraphSage
class GraphClassifier(nn.Module):
def __init__(self,params,relation2id): #, in_feats, h_feats
super().__init__()
self.params = params
self.relation2id = relation2id
self.train_rels = params.train_rels
self.relations = params.num_rels
self.gnn = nn.ModuleList()
self.gnn.append(SAGEConv(self.params.feat_dim, self.params.emb_dim, aggregator_type='mean'))
for i in range(self.params.num_gcn_layers - 1):
self.gnn.append(SAGEConv(self.params.emb_dim, self.params.emb_dim, aggregator_type='mean'))
self.dropout = nn.Dropout(p=0.3)
self.relu = nn.ReLU()
self.mp_layer1 = nn.Linear(self.params.feat_dim, self.params.emb_dim)
self.mp_layer2 = nn.Linear(self.params.emb_dim, self.params.emb_dim)
# self.fc_layer = nn.Linear(self.params.emb_dim * (1+self.params.num_gcn_layers), self.train_rels)
# self.drugfeat = nn.Embedding(self.params.num_drugs, self.params.emb_dim)
# self.drugfeat.weight.data = torch.from_numpy(np.random.normal(0, 0.1, (self.params.num_drugs, self.params.emb_dim)))
# self.drugfeat.weight.requires_grad = True
# self.drugfeat = self.drugfeat.cuda()
if self.params.add_ht_emb and self.params.add_sb_emb:
if self.params.add_feat_emb and self.params.add_transe_emb:
self.fc_layer = nn.Linear(3 * (1+self.params.num_gcn_layers) * self.params.emb_dim + 2*self.params.emb_dim, self.train_rels)
elif self.params.add_feat_emb :
self.fc_layer = nn.Linear(3 * (self.params.num_gcn_layers) * self.params.emb_dim + 2*self.params.emb_dim, self.train_rels)
else:
self.fc_layer = nn.Linear(3 * (1+self.params.num_gcn_layers) * self.params.emb_dim, self.train_rels)
elif self.params.add_ht_emb:
self.fc_layer = nn.Linear(2 * (1+self.params.num_gcn_layers) * self.params.emb_dim, self.train_rels)
else:
self.fc_layer = nn.Linear(self.params.num_gcn_layers * self.params.emb_dim, self.train_rels)
def drug_feat(self, emb):
self.drugfeat = emb
def forward(self, data): #, g, drug_idx, head_ids, tail_ids
g = data
g.ndata['h'] = self.gnn(g)
g_out = mean_nodes(g, 'repr')
g.ndata['repr'] = self.drugfeat(g.ndata['idx'])
for i in range(self.params.num_gcn_layers):
g.ndata['repr'] = self.gnn[i](g, g.ndata['repr'])
g.ndata['repr'] = self.relu(g.ndata['repr'])
g.ndata['repr'] = self.dropout(g.ndata['repr'])
head_ids = (g.ndata['id'] == 1).nonzero().squeeze(1)
head_embs = g.ndata['repr'][head_ids]
tail_ids = (g.ndata['id'] == 2).nonzero().squeeze(1)
tail_embs = g.ndata['repr'][tail_ids]
head_feat = self.drugfeat[g.ndata['idx'][head_ids]]
tail_feat = self.drugfeat[g.ndata['idx'][tail_ids]]
# drug_feat = self.drugfeat[drug_idx]
# print(drug_feat, drug_feat.shape)
if self.params.add_feat_emb:
fuse_feat1 = self.mp_layer2( self.relu( self.dropout( self.mp_layer1(
head_feat #torch.cat([head_feat, tail_feat], dim = 1)
))))
fuse_feat2 = self.mp_layer2( self.relu( self.dropout( self.mp_layer1(
tail_feat #torch.cat([head_feat, tail_feat], dim = 1)
))))
fuse_feat = torch.cat([fuse_feat1, fuse_feat2], dim = 1)
if self.params.add_ht_emb and self.params.add_sb_emb:
if self.params.add_feat_emb and self.params.add_transe_emb:
g_rep = torch.cat([g_out.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
head_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
tail_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
fuse_feat.view(-1, 2*self.params.emb_dim)
], dim=1)
elif self.params.add_feat_emb:
g_rep = torch.cat([g_out.view(-1, (self.params.num_gcn_layers) * self.params.emb_dim),
head_embs.view(-1, (self.params.num_gcn_layers) * self.params.emb_dim),
tail_embs.view(-1, (self.params.num_gcn_layers) * self.params.emb_dim),
fuse_feat.view(-1, 2*self.params.emb_dim)
], dim=1)
else:
g_rep = torch.cat([g_out.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
head_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
tail_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
#fuse_feat.view(-1, 2*self.params.emb_dim)
], dim=1)
elif self.params.add_ht_emb:
g_rep = torch.cat([
head_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim),
tail_embs.view(-1, (1+self.params.num_gcn_layers) * self.params.emb_dim)
], dim=1)
else:
g_rep = g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim)
#print(g_rep.shape, self.params.add_ht_emb, self.params.add_sb_emb)
output = self.fc_layer(F.dropout(g_rep, p =0.3))
# print(head_ids.detach().cpu().numpy(), tail_ids.detach().cpu().numpy())
return output
Please let me know if their is any implementation error in graph sage that is causing this error?.