I am trying to apply the GCN2Conv layers with HetroConv module of dgl, but every time I find the "TypeError: forward() missing 1 required positional argument: ‘feat_0’ ". Can you help to fix? The code which I have used is shown below:
import dgl
import torch
from dgl.nn import HeteroGraphConv, GCN2Conv
edges1 = (torch.tensor([0,0,1,1,2]), torch.tensor([2,3,4,4,5]))
edges2 = (torch.tensor([0,1,2,3,4]), torch.tensor([2, 3, 4, 5, 5]))
# random edge masks
mask1 = (torch.rand(5,1)>0.5).float()
mask2 = (torch.rand(5,1)>0.5).float()
#print(type(mask1)) #tenosr
#print(mask1.shape)
#print(mask2)
graph_data = {
('user', 'follows', 'user'): edges1,
('user', 'plays', 'game'): edges2
}
# convert the graph data to the hetrogeneous graph
g = dgl.heterograph(graph_data)
#print("g", g)
h1 = {
'user': torch.randn((g.num_nodes("user"), 5)),
'game': torch.randn((g.num_nodes("game"), 5))
}
lin1 = torch.nn.Linear(5, 6)
conv1 = HeteroGraphConv({
'follows': GCN2Conv(6, layer=1, alpha=0.5, project_initial_features=True),
'plays': GCN2Conv(6, layer=1, alpha=0.5, project_initial_features=True),
}, aggregate="sum")
conv2 = HeteroGraphConv({
'follows': GCN2Conv(6, layer=1, alpha=0.5, project_initial_features=True),
'plays': GCN2Conv(6, layer=1, alpha=0.5, project_initial_features=True),
}, aggregate="sum")
# pas masks as edge_weight here
inp1 = {key: lin1(x) for key, x in h1.items()}
res = conv1(g, inp1, inp1, mod_kwargs={
'follows':{'edge_weight':mask1},
'plays':{'edge_weight': mask2}
}
)
res = conv2(g, res, inp1, mod_kwargs={
'follows':{'edge_weight':mask1},
'plays':{'edge_weight': mask2}
}
)
My error is shown below:
Traceback (most recent call last):
File "/DGL/GCNII/example1.py", line 49, in <module>
res = conv1(g, inp1, inp1, mod_kwargs={
File "/dgl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/.conda/envs/dgl/lib/python3.8/site-packages/dgl/nn/pytorch/hetero.py", line 185, in forward
dstdata = self.mods[etype](
File "/.conda/envs/dgl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'feat_0'