Thank you for the reply — I managed to get it working. I'm attaching the code below in case others run into the same issue.
def construct_negative_graph(graph, k):
    """Build a negative-sample graph with ``k`` corrupted edges per real edge.

    For every canonical edge type ``(utype, etype, vtype)`` of ``graph``, the
    source node of each existing edge is repeated ``k`` times and paired with
    a destination drawn uniformly at random from all ``vtype`` nodes. Sampled
    destinations are not checked against existing edges, so a "negative" may
    occasionally coincide with a real edge.

    Args:
        graph: DGL heterograph to sample negatives from.
        k: number of negative edges to generate per existing edge.

    Returns:
        A ``dgl.heterograph`` with the same node types and node counts as
        ``graph`` and the same canonical edge types; edge type ``etype`` has
        ``graph.num_edges(etype) * k`` edges.
    """
    edge_dict = {}
    for etype in graph.canonical_etypes:
        _, _, vtype = etype
        src, _ = graph.edges(etype=etype)
        # repeat_interleave keeps the k negatives of one positive edge
        # adjacent, so scores can later be reshaped to (num_edges, k).
        neg_src = src.repeat_interleave(k)
        # Sample with the same integer dtype and on the same device as the
        # graph's node IDs: torch.randint otherwise defaults to int64 on the
        # CPU, which mismatches neg_src for int32 or GPU graphs.
        neg_dst = torch.randint(
            0,
            graph.num_nodes(vtype),
            (len(src) * k,),
            dtype=src.dtype,
            device=src.device,
        )
        edge_dict[etype] = (neg_src, neg_dst)
    # Node counts do not depend on the edge type, so compute them once.
    # Copying them keeps node IDs aligned with `graph`, even for node types
    # that have no outgoing/incoming edges.
    num_nodes_dict = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    return dgl.heterograph(edge_dict, num_nodes_dict)
class HeteroDotProductPredictor(nn.Module):
    """Score edges by the dot product of their endpoint representations."""

    def forward(self, graph, h, node_type):
        """Return dot-product scores for the edges of every edge type.

        Args:
            graph: DGL heterograph whose edges are scored. ``graph.ndata`` is
                assigned a single tensor, which assumes the graph has exactly
                one node type.
            h: mapping from node type to node representations (indexed with
                ``node_type``).
            node_type: key of ``h`` selecting the representations to use.

        Returns:
            Tensor of scores with one row per edge: the per-edge-type scores
            concatenated along dim 0 in ``graph.canonical_etypes`` order.

        Raises:
            ValueError: if ``graph`` has no edge types.
        """
        import torch  # local import keeps this class self-contained

        with graph.local_scope():
            graph.ndata['h'] = h[node_type]
            scores = []
            for etype in graph.canonical_etypes:
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
                scores.append(graph.edges[etype].data['score'])
            if not scores:
                raise ValueError("graph has no edge types to score")
            # Concatenate so no edge type's scores are dropped. Positive and
            # negative graphs that share the same edge types are iterated in
            # the same canonical_etypes order, so their scores stay aligned.
            # With a single edge type this equals that type's scores alone.
            return torch.cat(scores)
class Model(nn.Module):
    """Link-prediction model: an RGCN encoder followed by a dot-product scorer."""

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        # Named ``sage`` although it is an RGCN; renaming the attribute would
        # change the module's state_dict keys and break saved checkpoints.
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, node_type):
        """Encode nodes on ``g`` and score positive and negative edges.

        Args:
            g: graph holding the positive (real) edges; also used for
                message passing.
            neg_g: graph holding negative edges. It is scored with the node
                representations computed on ``g``, so it is assumed to share
                ``g``'s node types and node IDs.
            x: input node features passed to the encoder.
            node_type: node type whose representations the scorer uses.

        Returns:
            Tuple ``(pos_score, neg_score)`` of edge scores for ``g`` and
            ``neg_g`` respectively.
        """
        h = self.sage(g, x)
        return self.pred(g, h, node_type), self.pred(neg_g, h, node_type)
Small note: I pass `node_type` as an argument because my heterograph has only one node type.