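In the context of Section 5.3 of the DGL user guide, you can replace `HeteroDotProductPredictor` with the code snippet below. It passes the node representations through a separate MLP for each edge type before computing the dot-product edge scores: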
class MLP(nn.Module):
    """Two-layer perceptron mapping ``out_feats``-dim inputs to ``out_feats``-dim outputs.

    Structure: Linear(out_feats, out_feats) -> ReLU -> Linear(out_feats, out_feats).

    Args:
        out_feats: Width of the input, hidden and output layers.
    """

    def __init__(self, out_feats):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it PyTorch raises AttributeError on ``self.layer = ...``.
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(out_feats, out_feats),
            nn.ReLU(),
            nn.Linear(out_feats, out_feats),
        )

    def forward(self, x):
        """Apply the MLP to ``x``, whose last dimension must be ``out_feats``."""
        return self.layer(x)
class HeteroDotProductPredictor(nn.Module):
    """Score edges of one edge type by a dot product of per-edge-type projected node embeddings.

    Each edge type gets its own ``MLP``. Node representations are passed through
    the MLP of the requested edge type before the source/destination dot product
    is computed with ``fn.u_dot_v``.

    Args:
        out_feats: Dimensionality of the node representations (input and output
            size of every projection MLP).
        etypes: Edge type names to create a projection for. Defaults to
            ``('link1', 'link2', 'link3')``, the edge types of the original snippet.
    """

    def __init__(self, out_feats, etypes=('link1', 'link2', 'link3')):
        super().__init__()
        # nn.ModuleDict (not a plain dict) registers every MLP as a submodule, so
        # its parameters show up in .parameters(), get updated by the optimizer,
        # and are moved by .to(device).
        self.etype_project = nn.ModuleDict(
            {name: MLP(out_feats) for name in etypes}
        )

    def forward(self, graph, h, etype):
        """Return the score of every edge of type ``etype``.

        Args:
            graph: Graph whose edges of type ``etype`` are scored.
            h: Node representations. For a heterogeneous graph this is a dict
                mapping node type to tensor (as computed by the GNN in Section 5.1);
                a single tensor is also accepted for graphs with one node type.
            etype: Edge type name, or a canonical ``(src_type, etype, dst_type)`` tuple.

        Returns:
            The ``'score'`` edge data computed for edges of ``etype``.
        """
        # The projection table is keyed by relation name; accept canonical tuples too.
        key = etype[1] if isinstance(etype, tuple) else etype
        project = self.etype_project[key]
        # h is a per-node-type dict for heterographs: project each tensor
        # individually rather than passing the dict itself to the MLP.
        if isinstance(h, dict):
            projected = {ntype: project(feat) for ntype, feat in h.items()}
        else:
            projected = project(h)
        with graph.local_scope():
            graph.ndata['h'] = projected
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']