Link prediction on all edge types

I’m following the link prediction tutorial, but it only handles a single edge type. How should I approach this, or what needs to change, to train and evaluate on all edge types?

Is it also possible to obtain metrics for each type of edge?

Thank you.

You could extend HeteroDotProductPredictor, construct_negative_graph, and the related function calls so that they iterate over all edge types. Below is an example for your reference.

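import torch.nn as nn
import dgl.function as fn

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h):
        # h: dict mapping node type -> node embeddings (graph with several node types)
        with graph.local_scope():
            graph.ndata['h'] = h
            for etype in graph.canonical_etypes:
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            # Return one score tensor per canonical edge type, not only the last one
            return {etype: graph.edges[etype].data['score']
                    for etype in graph.canonical_etypes}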
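To show why the predictor should hand back a score for every edge type (rather than only the scores of the last `etype` visited by the loop), here is a framework-free sketch of the same computation: for each canonical edge type `(utype, rel, vtype)`, the score of an edge `(u, v)` is the dot product of the source and destination embeddings. It uses plain Python lists instead of DGL graphs and tensors; the edge layout, the node type names, and the function name `dot_scores_per_etype` are hypothetical and only for illustration.

```python
def dot_scores_per_etype(edges, h):
    """Score every edge of every canonical edge type with u . v.

    edges: {(utype, rel, vtype): (src_ids, dst_ids)}  (hypothetical layout)
    h:     {ntype: list of embedding vectors}
    Returns {(utype, rel, vtype): list of scores}, one entry per edge type.
    """
    scores = {}
    for etype, (src, dst) in edges.items():
        utype, _, vtype = etype
        # Source embeddings come from utype, destination embeddings from vtype
        scores[etype] = [
            sum(a * b for a, b in zip(h[utype][u], h[vtype][v]))
            for u, v in zip(src, dst)
        ]
    return scores


# Toy heterograph with two node types and two relations
h = {
    "user": [[1.0, 0.0], [0.0, 1.0]],
    "item": [[2.0, 1.0], [0.5, 3.0], [1.0, 1.0]],
}
edges = {
    ("user", "clicks", "item"): ([0, 1], [0, 1]),
    ("item", "clicked-by", "user"): ([2], [0]),
}
scores = dot_scores_per_etype(edges, h)
for etype, s in scores.items():
    print(etype, s)
```

Keeping the result keyed by canonical edge type means the loss and any evaluation metric can later be computed either per relation or summed over all of them.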

Thank you for the reply. I managed to make it work, and I'm attaching the code in case others run into the same issue.

import dgl
import torch


def construct_negative_graph(graph, k):
    edge_dict = dict()
    for etype in graph.canonical_etypes:
        utype, _, vtype = etype
        src, dst = graph.edges(etype=etype)
        # Keep each source node and pair it with k random nodes of type vtype
        neg_src = src.repeat_interleave(k)
        neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,),
                                dtype=src.dtype, device=src.device)
        edge_dict[etype] = (neg_src, neg_dst)
    # Same node counts as the original graph (computed once), so node IDs match h
    num_nodes_dict = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    return dgl.heterograph(edge_dict, num_nodes_dict)
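The sampling scheme in construct_negative_graph can be sketched without DGL or PyTorch to make the shapes explicit: every positive edge of every edge type keeps its source node and gets k corrupted edges whose destinations are drawn uniformly from the node type `vtype` (which may differ from the source type). This is a plain-Python illustration under the assumption of a dict-of-lists edge layout; the function name `sample_negatives_per_etype` and the toy data are hypothetical.

```python
import random


def sample_negatives_per_etype(edges, num_nodes, k, seed=0):
    """Uniform negative sampling, done separately for each edge type.

    edges:     {(utype, rel, vtype): (src_ids, dst_ids)}  (hypothetical layout)
    num_nodes: {ntype: number of nodes of that type}
    Each positive edge keeps its source node and gets k corrupted edges whose
    destinations are drawn uniformly from the nodes of type vtype.
    """
    rng = random.Random(seed)
    neg_edges = {}
    for etype, (src, _) in edges.items():
        _, _, vtype = etype
        # Same layout as src.repeat_interleave(k): each source repeated k times
        neg_src = [u for u in src for _ in range(k)]
        # Destination IDs must be valid IDs of the destination node type
        neg_dst = [rng.randrange(num_nodes[vtype]) for _ in range(len(src) * k)]
        neg_edges[etype] = (neg_src, neg_dst)
    return neg_edges


num_nodes = {"user": 2, "item": 3}
edges = {
    ("user", "clicks", "item"): ([0, 1], [0, 1]),
    ("item", "clicked-by", "user"): ([2], [0]),
}
neg = sample_negatives_per_etype(edges, num_nodes, k=3)
for etype, (s, d) in neg.items():
    print(etype, s, d)
```

As in construct_negative_graph, the node counts are carried over unchanged, so the negative edges index the same node embeddings as the positive graph.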

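import torch.nn as nn
import dgl.function as fn


class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, node_type):
        with graph.local_scope():
            # Only one node type here, so a single tensor goes into ndata
            graph.ndata['h'] = h[node_type]
            for etype in graph.canonical_etypes:
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            # Return the scores of every edge type, not just the last one
            return {etype: graph.edges[etype].data['score']
                    for etype in graph.canonical_etypes}


class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        # RGCN: the heterogeneous encoder defined in the DGL tutorial (not shown)
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, node_type):
        h = self.sage(g, x)
        # Each call returns {canonical etype: scores}
        return self.pred(g, h, node_type), self.pred(neg_g, h, node_type)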

A small note: I pass node_type as an argument because my heterograph has only one node type.
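On the follow-up question about per-edge-type metrics: if the predictor returns one score tensor per canonical edge type (a dict keyed by etype), the loss and evaluation metrics can simply be computed per key. The sketch below works on plain Python lists (for example obtained with `score.squeeze(-1).tolist()`) and assumes the tutorial's margin (hinge) loss, where each positive edge is compared with its own k negatives, and ROC-AUC computed as the fraction of correctly ranked (positive, negative) pairs. The function names are hypothetical.

```python
def margin_loss(pos, neg, k):
    """Hinge loss max(0, 1 - pos + neg), averaged over all (pos, neg) pairs.

    neg holds k scores per positive edge, in the same order as the
    repeat_interleave(k) layout used by construct_negative_graph.
    """
    terms = []
    for i, p in enumerate(pos):
        for n in neg[i * k:(i + 1) * k]:
            terms.append(max(0.0, 1.0 - p + n))
    return sum(terms) / len(terms)


def auc(pos, neg):
    """ROC-AUC as the fraction of (pos, neg) pairs ranked correctly (ties = 0.5)."""
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


def per_etype_metrics(pos_scores, neg_scores, k):
    """Compute loss and AUC separately for every canonical edge type."""
    return {
        etype: {
            "loss": margin_loss(pos_scores[etype], neg_scores[etype], k),
            "auc": auc(pos_scores[etype], neg_scores[etype]),
        }
        for etype in pos_scores
    }


pos_scores = {
    ("user", "clicks", "item"): [2.0, 3.0],
    ("item", "clicked-by", "user"): [1.0],
}
neg_scores = {
    ("user", "clicks", "item"): [0.5, 2.5, 1.0, -1.0],  # k = 2 per positive
    ("item", "clicked-by", "user"): [1.0, 0.0],
}
metrics = per_etype_metrics(pos_scores, neg_scores, k=2)
for etype, m in metrics.items():
    print(etype, m)
```

For training, the same per-type loss would be computed on the tensors instead, and the per-type losses summed (or averaged) into a single scalar before calling backward(); reporting them separately shows which relations the model handles well.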

