Hi everyone, I have read the 5.1 tutorial for link prediction and visited the thread here. Following the suggestion there, I already modified my HeteroDotProductPredictor from `graph.ndata['h'] = h` to `graph.ndata['h'] = h[node_type]`, since I have only one node type and multiple edge types (13). However, the exact same error persists. Could you help me debug the problem?
Code snippet:
import dgl.function as fn
import torch
import torch.nn.functional as F
from dgl.nn import GraphConv, HeteroGraphConv

class HeteroDotProductPredictor(torch.nn.Module):
    def forward(self, graph, h, etype):
        with graph.local_scope():
            print(h)  # debug: inspect what h looks like
            graph.ndata['h'] = h['face']  # node type "face"
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

class Model(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

class RGCN(torch.nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        self.conv1 = HeteroGraphConv({
            rel: GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = HeteroGraphConv({
            rel: GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

def compute_loss(pos_score, neg_score):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()

train, val, test = ..  # process custom data
hetero_graph = train[0]  # get a graph for testing the training pipeline
k = 5
model = Model(hetero_graph.ndata["x"].shape[1], hetero_graph.num_nodes("face"), 1, hetero_graph.etypes)
node_features = hetero_graph.nodes["face"].data["x"].float()
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    negative_graph = construct_negative_graph(hetero_graph, k, ("face", "main", "face"))
    pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ("face", "main", "face"))
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
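(`construct_negative_graph` is not shown above; it follows the per-edge-type negative sampler from the user guide, which corrupts each positive edge by sampling k random destination nodes. Copied below so the snippet is self-contained:)

import dgl
import torch

def construct_negative_graph(graph, k, etype):
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    # repeat each source node k times and pair it with k random destinations
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})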
Error trace:
RuntimeError Traceback (most recent call last)
Cell In[21], line 8
6 for epoch in range(10):
7 negative_graph = construct_negative_graph(hetero_graph, k, ("face", "main", "face"))
----> 8 pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ("face", "main", "face"))
9 loss = compute_loss(pos_score, neg_score)
10 opt.zero_grad()
File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
Cell In[5], line 7, in Model.forward(self, g, neg_g, x, etype)
6 def forward(self, g, neg_g, x, etype):
----> 7 h = self.sage(g, x)
8 return self.pred(g, h, etype), self.pred(neg_g, h, etype)
File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
Cell In[6], line 14, in RGCN.forward(self, graph, inputs)
12 def forward(self, graph, inputs):
13 # inputs are features of nodes
---> 14 h = self.conv1(graph, inputs)
15 h = {k: F.relu(v) for k, v in h.items()}
16 h = self.conv2(graph, h)
File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\dgl\nn\pytorch\hetero.py:208, in HeteroGraphConv.forward(self, g, inputs, mod_args, mod_kwargs)
206 for stype, etype, dtype in g.canonical_etypes:
207 rel_graph = g[stype, etype, dtype]
--> 208 if stype not in inputs:
209 continue
210 dstdata = self._get_module((stype, etype, dtype))(
211 rel_graph,
212 (inputs[stype], inputs[dtype]),
213 *mod_args.get(etype, ()),
214 **mod_kwargs.get(etype, {})
215 )
File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\_tensor.py:1061, in Tensor.__contains__(self, element)
1055 if isinstance(
1056 element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool)
1057 ):
1058 # type hint doesn't understand the __contains__ result array
1059 return (element == self).any().item() # type: ignore[union-attr]
-> 1061 raise RuntimeError(
1062 f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}."
1063 )
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.
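If I read the trace correctly, the check that fails is `stype not in inputs` in HeteroGraphConv.forward, which suggests `inputs` arrived as a plain tensor rather than a dict keyed by node type, so `'face' in inputs` ends up calling Tensor.__contains__ with a string. A minimal sketch that reproduces the same RuntimeError (the tensor is just a stand-in for my feature matrix):

import torch

inputs = torch.randn(32, 8)  # a bare feature tensor, not {'face': tensor}
'face' in inputs  # RuntimeError: Tensor.__contains__ only supports Tensor or scalar ...

So I suspect node_features needs to be wrapped as {'face': node_features} before calling the model, but I'm not sure that is the intended usage.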
Update 1: here is an example of the input DGL graph:
Graph(num_nodes={'face': 32},
num_edges={('face', 'cir', 'face'): 4, ('face', 'con', 'face'): 0, ('face', 'cyn', 'face'): 0, ('face', 'deg', 'face'): 2, ('face', 'dia', 'face'): 3, ('face', 'fit', 'face'): 0, ('face', 'main', 'face'): 66, ('face', 'pm', 'face'): 0, ('face', 'ro', 'face'): 3, ('face', 'rou', 'face'): 0, ('face', 'stra', 'face'): 0, ('face', 'sym', 'face'): 0},
metagraph=[('face', 'face', 'cir'), ('face', 'face', 'con'), ('face', 'face', 'cyn'), ('face', 'face', 'deg'), ('face', 'face', 'dia'), ('face', 'face', 'fit'), ('face', 'face', 'main'), ('face', 'face', 'pm'), ('face', 'face', 'ro'), ('face', 'face', 'rou'), ('face', 'face', 'stra'), ('face', 'face', 'sym')])
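One more observation, in case it is related: since the graph has a single node type, hetero_graph.ndata["x"] returns a bare tensor rather than a dict (if I understand the DGL API correctly):

print(type(hetero_graph.ndata["x"]))  # torch.Tensor, not a dict, because there is only one node type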
Thank you in advance.