Hi, I am trying to use RGCN for node classification on a graph with one node type and two edge types. All nodes have pre-defined 200-dimensional features. I followed the tutorial here, but I get the following error:
```
Traceback (most recent call last):
  File "run_classifier_rgcn.py", line 363, in <module>
    main()
  File "run_classifier_rgcn.py", line 317, in main
    train(train_dataset, model, mlb, g, feats, edge_type, edge_norm, args.batch_sz, args.num_epochs, criterion, device,
  File "run_classifier_rgcn.py", line 165, in train
    output = model(text, G, feats, edge_type, edge_norm)
  File "/home/xdwang/dgl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/lustre03/project/6001103/xdwang/MeSH_Indexing_RGCN/model.py", line 404, in forward
    label_feature = self.rgcn(g, g_node_feature, edge_type, edge_norm)
  File "/home/xdwang/dgl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/lustre03/project/6001103/xdwang/MeSH_Indexing_RGCN/model.py", line 364, in forward
    h = layer(g, h, r, norm)
  File "/home/xdwang/dgl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/xdwang/dgl/lib/python3.8/site-packages/dgl/nn/pytorch/conv/relgraphconv.py", line 282, in forward
    g.update_all(self.message_func, fn.sum(msg='msg', out='h'))
  File "/home/xdwang/dgl/lib/python3.8/site-packages/dgl/heterograph.py", line 4499, in update_all
    ndata = core.message_passing(g, message_func, reduce_func, apply_node_func)
  File "/home/xdwang/dgl/lib/python3.8/site-packages/dgl/core.py", line 291, in message_passing
    msgdata = invoke_edge_udf(g, ALL, g.canonical_etypes[0], mfunc, orig_eid=orig_eid)
  File "/home/xdwang/dgl/lib/python3.8/site-packages/dgl/core.py", line 82, in invoke_edge_udf
    return func(ebatch)
  File "/home/xdwang/dgl/lib/python3.8/site-packages/dgl/nn/pytorch/conv/relgraphconv.py", line 209, in basis_message_func
    sub_msg = th.matmul(src, w)
RuntimeError: mat1 dim 1 must match mat2 dim 0
```
I followed the tutorial; the only difference is that I use my own dataset. I created a heterogeneous graph using `dgl.heterograph()`. My graph looks like this:

```
Graph(num_nodes={'mesh': 29368},
      num_edges={('mesh', 'cooccurrence', 'mesh'): 861, ('mesh', 'hierarchy', 'mesh'): 39920},
      metagraph=[('mesh', 'mesh', 'cooccurrence'), ('mesh', 'mesh', 'hierarchy')])
```

and the node features have this shape:

```
>>> G.ndata['feat'].shape
torch.Size([29368, 200])
```
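To narrow this down, here is a minimal sketch of the multiplication that fails in `basis_message_func`. There, `w` is the weight slice for one relation, of shape `(in_feat, out_feat)`, and `src` holds the source-node features for that relation's edges, here of shape `(num_edges, 200)`. Since model.py isn't shown, the sketch only assumes that the failing `RelGraphConv` layer was built with an `in_feat` other than 200; the value 29368 (an input layer sized for one-hot node IDs) and the other sizes are hypothetical, and the exact error text depends on the PyTorch version.

```python
import torch

# Illustrative sizes (hypothetical, except feat_dim, which matches G.ndata['feat'])
num_edges = 5        # a few edges of one relation type
feat_dim = 200       # node feature dimension from the post
bad_in_feat = 29368  # hypothetical in_feat: input layer sized for one-hot node IDs
out_feat = 16        # hypothetical hidden size

# Source-node features for the edges of one relation (plays the role of `src`)
src = torch.randn(num_edges, feat_dim)

# One relation's weight with a mismatched input size (plays the role of `w`)
w_bad = torch.randn(bad_in_feat, out_feat)

raised = False
try:
    torch.matmul(src, w_bad)  # (5, 200) @ (29368, 16): inner dimensions disagree
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# With in_feat equal to the feature dimension, the product is well defined
w_ok = torch.randn(feat_dim, out_feat)
msg = torch.matmul(src, w_ok)
print(msg.shape)  # torch.Size([5, 16])
```

If this is the cause, the failing layer's `in_feat` would need to equal `feats.shape[1]`, i.e. 200.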
Thanks!