Thanks for pointing out the documentation; I had not noticed this particularity. It now works, but I am having trouble with custom layers. I am trying to build a HeteroGraphConv with custom layers, and again my dimensions do not match.
Here is my custom Test_layer:
```python
import dgl
import dgl.function as fn
import torch.nn as nn


class Test_layer(nn.Module):  # Replaces SAGEConv layer
    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def __init__(self,
                 in_feats,
                 out_feats):
        super(Test_layer, self).__init__()
        self._in_src_feats, self._in_dst_feats = dgl.utils.expand_as_pair(in_feats)
        self._out_feats = out_feats
        # NN through which the initial 'h' state passes
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats)
        # NN through which the neighborhood messages pass
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats)
        self.reset_parameters()

    def forward(self, graph, feat):
        h_self = feat
        # store the input features on the source nodes
        graph.srcdata['h'] = h_self
        # update node values: copy source 'h' into messages 'm', average them into 'neigh'
        graph.update_all(
            fn.copy_src('h', 'm'),
            fn.mean('m', 'neigh'))
        h_neigh = graph.dstdata['neigh']
        print(h_self.shape)
        print(h_neigh.shape)
        # result is output of NN for h_self + output of NN for h_neigh
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        self.reset_parameters()  # note: this re-initializes the weights on every forward pass
        return rst
```
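For reference, the `update_all(fn.copy_src('h', 'm'), fn.mean('m', 'neigh'))` call averages, for every destination node, the `'h'` rows of its in-neighbors, so `graph.dstdata['neigh']` has one row per destination node rather than one per source node. Below is a minimal pure-Python sketch of that semantics (no DGL involved; `copy_src_mean` is a hypothetical helper, and zero-filling destination nodes without in-edges is an assumption meant to mirror DGL's built-in reducers):

```python
def copy_src_mean(src_feats, edges, num_dst):
    """Toy version of update_all(copy_src('h', 'm'), mean('m', 'neigh')).

    src_feats: one feature row (list of floats) per source node
    edges: (src_id, dst_id) pairs
    num_dst: number of destination nodes
    Returns one averaged row per destination node (zeros if it has no in-edges).
    """
    dim = len(src_feats[0])
    sums = [[0.0] * dim for _ in range(num_dst)]
    counts = [0] * num_dst
    for u, v in edges:
        counts[v] += 1  # one message per incoming edge
        for k in range(dim):
            sums[v][k] += src_feats[u][k]
    return [[x / counts[v] for x in sums[v]] if counts[v] else sums[v]
            for v in range(num_dst)]


# 3 source nodes (think "users") with 2 features, 2 destination nodes ("items")
user_feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
buys = [(0, 0), (1, 0), (2, 1)]
h_neigh = copy_src_mean(user_feats, buys, num_dst=2)
print(len(user_feats), len(h_neigh))  # 3 2
print(h_neigh)  # [[2.0, 3.0], [5.0, 6.0]]
```

The reduced tensor keeps the source feature width but takes the destination node count, so it can only be added to something that also has one row per destination node.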
Then I call the HeteroGraphConv layer with some user and item features:
```python
import torch
import dgl.nn as dglnn

# g is the user/item heterograph with 'buys' and 'bought-by' relations (built elsewhere)
layer = dglnn.HeteroGraphConv({'buys': Test_layer(5, 100),
                               'bought-by': Test_layer(3, 100)},
                              aggregate='sum')
item_feat = torch.randn((g.number_of_nodes('item'), 3))
user_feat = torch.randn((g.number_of_nodes('user'), 5))
features = {'item': item_feat, 'user': user_feat}
out = layer(g, features)
```
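As a mental model, HeteroGraphConv runs each relation's module on that relation's subgraph and then combines, per destination node type, the results with the chosen `aggregate`. The toy model below is written on plain lists, not DGL; `hetero_conv_sum` and `toy_mod` are hypothetical names, and it assumes the recent-DGL convention of handing each module a (source features, destination features) pair. It shows the constraint behind the shape error: every relation module must return exactly one row per destination-type node, otherwise the per-type sum cannot line up.

```python
from collections import defaultdict


def hetero_conv_sum(canonical_etypes, mods, features):
    """Toy model of HeteroGraphConv(..., aggregate='sum') on plain lists.

    canonical_etypes: (src_type, etype, dst_type) triples
    mods: etype -> callable(src_feat, dst_feat) returning one row per dst node
    features: node type -> list of feature rows
    """
    outputs = defaultdict(list)
    for stype, etype, dtype in canonical_etypes:
        out = mods[etype](features[stype], features[dtype])
        if len(out) != len(features[dtype]):
            raise ValueError(f"{etype!r} returned {len(out)} rows, "
                             f"expected {len(features[dtype])} ({dtype} nodes)")
        outputs[dtype].append(out)
    # 'sum': add the relation outputs element-wise, node by node, per dst type
    return {dtype: [[sum(vals) for vals in zip(*rows)] for rows in zip(*outs)]
            for dtype, outs in outputs.items()}


def toy_mod(src_feat, dst_feat):
    # hypothetical stand-in for Test_layer: one output row per destination node
    return [[float(len(src_feat))] for _ in dst_feat]


etypes = [("user", "buys", "item"), ("item", "bought-by", "user")]
features = {"user": [[0.0] * 5] * 3, "item": [[0.0] * 3] * 2}
out = hetero_conv_sum(etypes, {"buys": toy_mod, "bought-by": toy_mod}, features)
print({k: len(v) for k, v in out.items()})  # {'item': 2, 'user': 3}
```

A module that returns one row per *source* node (which is what adding `fc_self(h_self)` with source features amounts to) would trip the `ValueError` in this toy, analogous to the size mismatch reported below.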
I get the following error.
RuntimeError: The size of tensor a (1892) must match the size of tensor b (2492) at non-singleton dimension 0
This makes sense, since my h_self has shape torch.Size([1892, 5]) while my h_neigh has shape torch.Size([2492, 5]).
How can I arrange this so that h_neigh gets the right shape, i.e. 1892 rows as well? I thought that my reduce function fn.mean would do that, but it appears I am not using it properly.
Again, thanks for your answers; I'm a newbie trying to better understand heterographs in DGL!