Hi,
I am trying to apply HeteroGraphConv with the SAGEConv module.
My graph has multiple node types, and each type has a different number of nodes. I am getting a dimension error related to the number of nodes. Here is example code:
import dgl
import dgl.nn as dglnn
import torch as th

# Minimal reproduction: HeteroGraphConv wrapping SAGEConv on a graph whose
# source node type ('type1') and destination node type ('type2') have
# different node counts.

# One relation: type1 -[link]-> type2. The source IDs (0..2) and destination
# IDs (0..1) line up with the 3 x 6 and 2 x 6 feature tensors below.
edges = [(0, 0), (0, 1), (1, 1), (2, 1)]
hg = dgl.heterograph({('type1', 'link', 'type2'): edges})

# SAGEConv(6, 2, 'mean'): 6 input features -> 2 output features.
conv = dglnn.HeteroGraphConv(
    {'link': dglnn.SAGEConv(6, 2, 'mean')},
    aggregate='stack',
)

# One feature row per node of each type: 3 'type1' rows, 2 'type2' rows.
node_feat = {'type1': th.rand(3, 6).float(), 'type2': th.rand(2, 6).float()}

# This call raises:
#   RuntimeError: The size of tensor a (3) must match the size of tensor b (2)
# The traceback shows HeteroGraphConv handing only inputs[stype] (the 3
# 'type1' rows) to SAGEConv. SAGEConv then adds a self term fc_self(h_self)
# built from those 3 rows to a neighbour term fc_neigh(h_neigh) with one row
# per 'type2' node (2 rows). GraphConv presumably has no separate self term,
# which would explain why it runs.
# Likely workaround (verify for your DGL version): pass a
# (source features, destination features) pair so each relation gets
# (inputs[stype], inputs[dtype]), e.g. conv(hg, (node_feat, node_feat)).
h1 = conv(hg, node_feat)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/usr/local/bin/kernel-launchers/python/scripts/launch_ipykernel.py in <module>
7
8 node_feat = {'type1': th.rand(3,6).float(), 'type2': th.rand(2,6).float()}
----> 9 h1 = conv(hg,node_feat)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/lib/python3.6/site-packages/dgl/nn/pytorch/hetero.py in forward(self, g, inputs, mod_args, mod_kwargs)
172 inputs[stype],
173 *mod_args.get(etype, ()),
--> 174 **mod_kwargs.get(etype, {}))
175 outputs[dtype].append(dstdata)
176 rsts = {}
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/opt/conda/lib/python3.6/site-packages/dgl/nn/pytorch/conv/sageconv.py in forward(self, graph, feat)
215 rst = self.fc_neigh(h_neigh)
216 else:
--> 217 rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
218 # activation
219 if self.activation is not None:
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0
When I run the same code with GraphConv instead, it works fine.
Any idea why it does not work with SAGEConv?
Thanks!