hi everyone, i wonder how adapt my heterogeneous graph at graphSAGE.
my code snippet is as follows:
import dgl
import numpy as np
import torch as th
from dgl.nn import SAGEConv
u = [0, 1, 0, 0, 1]
v = [0, 1, 2, 3, 2]
g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
u_feat = th.rand(4, 10)
v_feat = th.rand(4, 10)
#conv_user = SAGEConv(10, 2, 'mean')
#conv_item = SAGEConv(10, 2, 'mean')
conv_user = SAGEConv((10,10), 2, 'mean')
conv_item = SAGEConv((10, 10), 2, 'mean')
preds1 = conv_user(g, (u_feat, v_feat))
preds2 = conv_item(g, (v_feat, u_feat))
print("user embedding", preds1)
print("item embedding", preds2)
but I find the following error as result:
---------------------------------------------------------------------------
DGLError Traceback (most recent call last)
<ipython-input-20-bb94bddd4168> in <cell line: 16>()
14 conv_item = SAGEConv((10, 10), 2, 'mean')
15
---> 16 preds1 = conv_user(g, (u_feat, v_feat))
17 preds2 = conv_item(g, (v_feat, u_feat))
18
3 frames
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
/usr/local/lib/python3.9/dist-packages/dgl/nn/pytorch/conv/sageconv.py in forward(self, graph, feat, edge_weight)
222 # Message Passing
223 if self._aggre_type == 'mean':
--> 224 graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
225 graph.update_all(msg_fn, fn.mean('m', 'neigh'))
226 h_neigh = graph.dstdata['neigh']
/usr/local/lib/python3.9/dist-packages/dgl/view.py in __setitem__(self, key, val)
97 "please pass a tensor directly"
98 )
---> 99 self._graph._set_n_repr(self._ntid, self._nodes, {key: val})
100
101 def __delitem__(self, key):
/usr/local/lib/python3.9/dist-packages/dgl/heterograph.py in _set_n_repr(self, ntid, u, data)
4030 nfeats = F.shape(val)[0]
4031 if nfeats != num_nodes:
-> 4032 raise DGLError('Expect number of features to match number of nodes (len(u)).'
4033 ' Got %d and %d instead.' % (nfeats, num_nodes))
4034 if F.context(val) != self.device:
DGLError: Expect number of features to match number of nodes (len(u)). Got 4 and 2 instead.
who can solve this problem?!
thanks .