@mufeili I have basically followed the implementation from the rgcn link-predict example. `model` is an instance of `LinkPredict`, whose `forward()` method just delegates to `BaseRGCN`:
```python
class LinkPredict(nn.Module):
    def __init__(self,
                 heterograph: Heterograph,
                 g,
                 h_dim,
                 out_dim,
                 device,
                 num_hidden_layers=1,
                 dropout=0,
                 use_self_loop=True,
                 reg_param=0,
                 ):
        super(LinkPredict, self).__init__()
        ...
        self.rgcn: BaseRGCN = BaseRGCN(
            g=g,
            h_dim=h_dim,
            out_dim=out_dim,
            device=device,
            num_hidden_layers=num_hidden_layers,
            dropout=dropout,
            use_self_loop=use_self_loop
        )
        ...

    def forward(self, g):
        return self.rgcn.forward(g)
```
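The `...` above elides, among other things, the triplet scoring part, which in the upstream rgcn link-predict example is a DistMult scorer over the node embeddings. A rough standalone sketch of that piece (`DistMultScorer`, `num_rels`, and `triplets` are illustrative names, not my actual code):

```python
import torch as th
import torch.nn as nn

class DistMultScorer(nn.Module):
    """Standalone sketch of the DistMult triplet scorer used in the
    upstream rgcn link-predict example (illustrative, not my code)."""
    def __init__(self, num_rels, h_dim):
        super().__init__()
        self.w_relation = nn.Parameter(th.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(self.w_relation,
                                gain=nn.init.calculate_gain('relu'))

    def forward(self, embedding, triplets):
        # DistMult: score(s, r, o) = sum_k e_s[k] * w_r[k] * e_o[k]
        s = embedding[triplets[:, 0]]
        r = self.w_relation[triplets[:, 1]]
        o = embedding[triplets[:, 2]]
        return th.sum(s * r * o, dim=1)
```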
`BaseRGCN` is defined in my case as follows:
```python
class BaseRGCN(nn.Module):
    def __init__(self,
                 g,
                 h_dim,
                 out_dim,
                 device: th.device,
                 num_hidden_layers=1,
                 dropout=0,
                 use_self_loop=True):
        super(BaseRGCN, self).__init__()
        ...
        self.layers: nn.ModuleList = nn.ModuleList()
        self.embed_layer = RelGraphEmbed(g=self.g, embed_size=self.h_dim, device=self.device)
        # Embedding to input layer
        self.layers.append(RelGraphConvLayer(
            in_feat=self.h_dim, out_feat=self.h_dim, rel_names=self.rel_names,
            device=device, g=self.g,
            activation=F.relu, self_loop=self.use_self_loop, dropout=self.dropout)
        )
        # Hidden to hidden layers
        for i in range(self.num_hidden_layers):
            self.layers.append(RelGraphConvLayer(
                in_feat=self.h_dim, out_feat=self.h_dim, rel_names=self.rel_names,
                device=device, g=self.g,
                activation=F.relu, self_loop=self.use_self_loop, dropout=self.dropout)
            )
        # Hidden to output layer
        self.layers.append(RelGraphConvLayer(
            in_feat=self.h_dim, out_feat=self.out_dim, rel_names=self.rel_names,
            device=device, g=self.g,
            activation=None, self_loop=self.use_self_loop)
        )

    def forward(self, batched_graph):
        batched_graph = batched_graph.local_var()
        hs = self.embed_layer(batched_graph)
        for idx, layer in enumerate(self.layers):
            hs = layer.forward(batched_graph, inputs=hs)
        return hs
```
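For completeness, `RelGraphEmbed` follows the rgcn-hetero example: it holds one learnable embedding matrix per node type and returns them as a dict, so `hs` starts out as `{ntype: tensor}`. A minimal sketch along those lines (my `device` argument is elided here; this mirrors the upstream shape, not my exact code):

```python
class RelGraphEmbed(nn.Module):
    """Learnable per-node-type embeddings, as in the rgcn-hetero example."""
    def __init__(self, g, embed_size):
        super().__init__()
        self.embeds = nn.ParameterDict()
        for ntype in g.ntypes:
            embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), embed_size))
            nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
            self.embeds[ntype] = embed

    def forward(self, g=None):
        # Returns {ntype: (num_nodes_of_type, embed_size) tensor}
        return self.embeds
```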
This in turn calls the `forward()` function of each layer, i.e. of `RelGraphConvLayer`:
```python
class RelGraphConvLayer(nn.Module):
    r"""Relational graph convolution layer.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    rel_names : list[str]
        Relation names.
    device : torch.device
        Device to place the layer on.
    g : DGLHeteroGraph
        The graph.
    weight : bool, optional
        True if a linear layer is applied after message passing. Default: True
    bias : bool, optional
        True if bias is added. Default: True
    activation : callable, optional
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: True
    dropout : float, optional
        Dropout rate. Default: 0.0
    """
    def __init__(self,
                 in_feat,
                 out_feat,
                 rel_names,
                 device,
                 g,
                 *,
                 weight=True,
                 bias=True,
                 activation=None,
                 self_loop=True,
                 dropout=0.0):
        super(RelGraphConvLayer, self).__init__()
        # added self.g
        self.g = g
        # changed from regular dglnn.HeteroGraphConv to CustomHeteroGraphConv
        self.conv = dglnn.HeteroGraphConv({
            rel[1]: CustomHeteroGraphConv(g=self.g, in_feats=in_feat, out_feats=out_feat)
            for rel in rel_names
        })
        ...

    def forward(self, g, inputs):
        """Forward computation

        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph.
        inputs : dict[str, torch.Tensor]
            Node feature for each node type.

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type.
        """
        g = g.local_var()
        if self.use_weight:
            weight = self.basis() if self.use_basis else self.weight
            wdict = {self.rel_names[i]: {'weight': w.squeeze(0)}
                     for i, w in enumerate(th.split(weight, 1, dim=0))}
        else:
            wdict = {}
        if g.is_block:
            inputs_src = inputs
            inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        else:
            inputs_src = inputs_dst = inputs
        hs = self.conv(g, inputs_src, mod_kwargs=wdict)

        def _apply(ntype, h):
            if self.self_loop:
                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
```
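If I read `dglnn.HeteroGraphConv` correctly, it slices the heterograph into one bipartite subgraph per canonical edge type and calls the registered submodule for that relation with the node features packed as a `(src_feat, dst_feat)` tuple, plus whatever comes in via `mod_kwargs`. So I would expect `CustomHeteroGraphConv` to need an interface roughly like this sketch (the mean-aggregation body is just a placeholder, not my actual implementation):

```python
import torch.nn as nn
import dgl.function as fn

class CustomHeteroGraphConv(nn.Module):
    """Sketch of the per-relation interface dglnn.HeteroGraphConv expects."""
    def __init__(self, g, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, graph, feat, weight=None):
        # graph: the bipartite slice for one canonical edge type
        # feat: a (src_feat, dst_feat) tuple of node features
        # weight: forwarded from mod_kwargs (the wdict above), if used
        feat_src, feat_dst = feat
        with graph.local_scope():
            graph.srcdata['h'] = self.linear(feat_src)
            # Placeholder aggregation: mean over incoming messages
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
            return graph.dstdata['h']
```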
And this is the point where `CustomHeteroGraphConv` comes into play, inside `self.conv`. So do I understand correctly that when

`edge_predictions = model(edge_subgraph, blocks, input_features)`

is called, the `forward()` method of `CustomHeteroGraphConv` is eventually invoked, and that it receives the inputs `g` (the batched graph) and `h`? What exactly is `h` in this case?
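To make the question concrete, here is a hypothetical probe I could drop into `CustomHeteroGraphConv.forward()` to see what actually arrives as `h` (debugging code only, not part of the model):

```python
def forward(self, graph, feat, weight=None):
    # Hypothetical debug probe: what does HeteroGraphConv pass as `h`/`feat`?
    if isinstance(feat, tuple):
        print('h is a (src_feat, dst_feat) tuple:', feat[0].shape, feat[1].shape)
    else:
        print('h is a single tensor:', feat.shape)
    ...
```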