- You can safely replace `self.rgcn.forward(blocks, h)` with `self.rgcn(blocks, h)`. This is because `self.rgcn` inherits from `torch.nn.Module`, whose `__call__(self, ...)` invokes `forward(self, ...)`.
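The dispatch can be illustrated with a simplified, framework-free stand-in for `torch.nn.Module`. This is not PyTorch's actual implementation (the real `__call__` does more bookkeeping), and `Module`, `ToyRGCN`, and the feature values are hypothetical. It also shows why calling the module is generally preferred over calling `forward()` directly: registered forward hooks only run through `__call__`.

```python
class Module:
    """Simplified stand-in for torch.nn.Module (NOT the real implementation).

    Calling an instance runs forward() and then any registered forward hooks;
    calling forward() directly skips the hooks.
    """

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hook signature mirrors PyTorch's: hook(module, inputs, output)
        self._forward_hooks.append(hook)

    def forward(self, *args):
        raise NotImplementedError

    def __call__(self, *args):
        output = self.forward(*args)
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            if result is not None:  # a hook may replace the output
                output = result
        return output


class ToyRGCN(Module):
    """Hypothetical model: the 'message passing' just scales by the number of blocks."""

    def forward(self, blocks, h):
        return {ntype: feats * len(blocks) for ntype, feats in h.items()}


calls = []
rgcn = ToyRGCN()
rgcn.register_forward_hook(lambda mod, inp, out: calls.append(type(mod).__name__))

blocks, h = ["block0", "block1"], {"user": 1.5, "item": 2.0}
via_call = rgcn(blocks, h)             # runs forward() and the hook
via_forward = rgcn.forward(blocks, h)  # runs forward() only
print(via_call == via_forward, calls)  # True ['ToyRGCN']
```

Both calls return the same result; only the call through the instance triggers the hook.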
I have defined `self.predictor` as an instance of `ScorePredictor`, but I have a feeling this is not going to work for heterographs (with multiple node types). Do I have to iterate through the node types here?
- For a heterogeneous graph with multiple node types, check whether `x` is a dictionary mapping node types to the corresponding features. If so, this should work. For link prediction, `ScorePredictor` iterates over all canonical edge types and thereby implicitly over the associated node types. That said, you may not want to perform link prediction for all edge types, but only for a subset of them.
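Assuming `ScorePredictor` is the dot-product predictor from the DGL user guide (it assigns `x` to `edge_subgraph.ndata['x']` and calls `apply_edges` with `u_dot_v` once per entry of `edge_subgraph.canonical_etypes`), its logic can be mimicked in plain Python. This sketch is not DGL code: the `edges` layout, `score_edges`, and the `etypes` argument are hypothetical. It only shows that each canonical edge type `(src_type, relation, dst_type)` already names the node types to look up in `x`, and how restricting scoring to a subset of edge types might look.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def score_edges(edges, x, etypes=None):
    """Score each edge as the dot product of its endpoint features.

    edges:  {(src_type, rel, dst_type): (src_ids, dst_ids)}  (hypothetical layout)
    x:      {node_type: list of feature vectors}
    etypes: optional subset of canonical edge types to score
    """
    if not isinstance(x, dict):
        raise TypeError("x must map node types to features for a heterograph")
    scores = {}
    for canonical_etype, (src_ids, dst_ids) in edges.items():
        if etypes is not None and canonical_etype not in etypes:
            continue  # skip edge types we do not want to predict
        src_type, _, dst_type = canonical_etype
        # No separate loop over node types: the edge type tells us which ones to use.
        scores[canonical_etype] = [
            dot(x[src_type][s], x[dst_type][d]) for s, d in zip(src_ids, dst_ids)
        ]
    return scores


x = {"user": [[1.0, 0.0], [0.0, 1.0]], "item": [[2.0, 3.0]]}
edges = {
    ("user", "clicks", "item"): ([0, 1], [0, 0]),
    ("item", "clicked-by", "user"): ([0], [1]),
}
all_scores = score_edges(edges, x)
clicks_only = score_edges(edges, x, etypes={("user", "clicks", "item")})
print(clicks_only)  # {('user', 'clicks', 'item'): [2.0, 3.0]}
```

Passing `etypes` is one way to score only the relations you actually want to predict; in DGL you would similarly loop over a chosen list of canonical edge types instead of all of `canonical_etypes`.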
And what does `x` in `edge_subgraph.ndata['x']` refer to? Does it have to be changed to `h_dst`, which is used in `CustomHeteroGraphConv`?
- For link prediction, you need the updated representations of both source and destination nodes, hence `x` is the updated features of all nodes, i.e. the output of the GNN.
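A plain-Python sketch of where `x` comes from may help. Everything here is hypothetical (the dict-based block layout, scalar features, and a sum-aggregation layer standing in for `CustomHeteroGraphConv`); it only mirrors DGL's convention that a block's destination nodes are the first of its source nodes. `h_dst` exists only inside each layer, as the slice of that layer's input belonging to destination nodes, whereas the output of the last layer covers every node of the pair graph and is what would be assigned to `edge_subgraph.ndata['x']`.

```python
def toy_hetero_layer(block, h):
    """One sum-aggregation layer on a block (hypothetical data layout).

    block: {"num_dst": {ntype: int}, "edges": {(src_t, rel, dst_t): (src_ids, dst_ids)}}
    h:     {ntype: list of scalar features for the block's SOURCE nodes}
    Returns features for the block's DESTINATION nodes only.
    """
    # h_dst: inputs of the destination nodes, which come first among the source nodes
    h_dst = {nt: h[nt][:n] for nt, n in block["num_dst"].items()}
    out = {nt: list(feats) for nt, feats in h_dst.items()}  # self-loop term
    for (src_t, _, dst_t), (src_ids, dst_ids) in block["edges"].items():
        for s, d in zip(src_ids, dst_ids):
            out[dst_t][d] += h[src_t][s]  # sum messages from in-neighbors
    return out


blocks = [
    {"num_dst": {"user": 2, "item": 1},
     "edges": {("user", "clicks", "item"): ([0, 2], [0, 0]),
               ("item", "clicked-by", "user"): ([0, 1], [0, 1])}},
    {"num_dst": {"user": 1, "item": 1},
     "edges": {("user", "clicks", "item"): ([1], [0]),
               ("item", "clicked-by", "user"): ([0], [0])}},
]
h = {"user": [1.0, 2.0, 3.0], "item": [10.0, 20.0]}
for block in blocks:
    h = toy_hetero_layer(block, h)
x = h  # final GNN output: one entry per pair-graph node, for every node type
print(x)  # {'user': [25.0], 'item': [36.0]}
```

Because `x` holds updated features for every node type in the pair graph, the predictor can read both the source and the destination endpoint of any edge type from it.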
Should the `forward()` function of `BaseRGCN` then additionally take `h` as an input instead of constructing it inside the function? (I commented out the lines that were used before):
- It really depends on how you want to set up the model, i.e. whether you want to pre-process the input features before invoking message passing.
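The two options could look roughly like this. Both classes, the `layer` stand-in, and the `input_ids` field are hypothetical. In DGL, the original IDs of a block's input nodes would typically come from `blocks[0].srcdata[dgl.NID]`, and a learnable table would usually be an embedding module rather than a plain list.

```python
def layer(block, h):
    """Stand-in for one round of message passing: keeps the destination
    nodes' rows and doubles them (NOT a real GNN layer)."""
    return {nt: [2 * v for v in h[nt][:n]] for nt, n in block["num_dst"].items()}


class RGCNTakesFeatures:
    """Option 1: the caller builds h (e.g. pre-processed features) and passes it in."""

    def forward(self, blocks, h):
        for block in blocks:
            h = layer(block, h)
        return h


class RGCNBuildsFeatures:
    """Option 2: forward() constructs h itself from a per-node-type lookup table."""

    def __init__(self, embeddings):
        self.embeddings = embeddings  # {ntype: list of per-node features}

    def forward(self, blocks):
        # Look up input features by the original IDs of the first block's source nodes.
        h = {nt: [self.embeddings[nt][i] for i in ids]
             for nt, ids in blocks[0]["input_ids"].items()}
        for block in blocks:
            h = layer(block, h)
        return h


embeddings = {"user": [0.5, 1.0, 1.5, 2.0], "item": [3.0, 4.0]}
blocks = [
    {"input_ids": {"user": [3, 1], "item": [0]}, "num_dst": {"user": 1, "item": 1}},
    {"num_dst": {"user": 1, "item": 1}},
]
# Option 1: pre-process outside the model, then pass h in.
h_outside = {nt: [embeddings[nt][i] for i in ids]
             for nt, ids in blocks[0]["input_ids"].items()}
out1 = RGCNTakesFeatures().forward(blocks, h_outside)
# Option 2: let the model build h from its own table.
out2 = RGCNBuildsFeatures(embeddings).forward(blocks)
print(out1 == out2, out1)  # True {'user': [8.0], 'item': [12.0]}
```

Option 1 keeps `forward()` agnostic to where the features come from (raw, pre-processed, or computed elsewhere); option 2 ties the model to its own input table, which can be convenient when nodes have no features and you want to learn one embedding per node.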