Link Prediction Error: RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>

Hi everyone, I have read the 5.1 tutorial in the documentation for link prediction and visited the thread here. Following the suggestion there, I already changed my HeteroDotProductPredictor from graph.ndata['h'] = h to graph.ndata['h'] = h[node_type], since I have only one node type and multiple edge types (13). However, the exact same error persists. Could you help me debug the problem?

Code snippet:

class HeteroDotProductPredictor(torch.nn.Module):
    """Edge scorer: dot product of the 'h' features at both ends of each edge."""

    def forward(self, graph, h, etype):
        """Return the 'score' edge data of ``etype`` computed from ``h['face']``."""
        # local_scope keeps the temporary 'h' / 'score' fields off the caller's graph.
        with graph.local_scope():
            print(h)
            face_repr = h['face']  # node type "face"
            graph.ndata['h'] = face_repr
            dot_product = fn.u_dot_v('h', 'h', 'score')
            graph.apply_edges(dot_product, etype=etype)
            edge_scores = graph.edges[etype].data['score']
        return edge_scores

class Model(torch.nn.Module):
    """Link-prediction model: an RGCN encoder plus a dot-product edge scorer."""
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        """Build the RGCN encoder (stored as ``self.sage``) and the predictor.

        All four arguments are forwarded unchanged to ``RGCN``.
        """
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, neg_g, x, etype):
        """Encode ``g`` once and score ``etype`` edges on ``g`` and on ``neg_g``.

        Returns a ``(positive_scores, negative_scores)`` tuple; both are
        computed from the same node representations ``h``.
        """
        # x is handed straight to the encoder, which passes it to HeteroGraphConv.
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

class RGCN(torch.nn.Module):
    """Two HeteroGraphConv layers, each holding one GraphConv per relation name."""
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        """Create both layers; per-relation results are combined with 'sum'."""
        super().__init__()

        # Layer 1 maps in_feats -> hid_feats for every relation in rel_names.
        self.conv1 = HeteroGraphConv({
            rel: GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        # Layer 2 maps hid_feats -> out_feats for every relation in rel_names.
        self.conv2 = HeteroGraphConv({
            rel: GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        """Apply conv1, ReLU on every entry of its output, then conv2."""
        # inputs are features of nodes
        # NOTE(review): HeteroGraphConv evaluates `stype not in inputs`, so
        # `inputs` is expected to be a dict keyed by node type, not a tensor.
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

def compute_loss(pos_score, neg_score):
    """Return the mean hinge (margin-1) loss between positive and negative scores.

    ``neg_score`` is regrouped into one row per positive edge before the
    positive scores are broadcast against it.
    """
    num_pos = pos_score.shape[0]
    neg_per_pos = neg_score.view(num_pos, -1)
    margin_violation = torch.clamp(1 - pos_score + neg_per_pos, min=0)
    return margin_violation.mean()

train, val, test = .. # process custom data
hetero_graph = train[0] # get a graph for testing the training pipeline
# k is forwarded to construct_negative_graph on every epoch.
k = 5
# Input size is the width of "x"; the hidden size is set to the number of
# "face" nodes; the output size is 1.
model = Model(hetero_graph.ndata["x"].shape[1], hetero_graph.num_nodes("face"), 1, hetero_graph.etypes)
# NOTE(review): this is a bare tensor. HeteroGraphConv evaluates
# `stype not in inputs` with a string node type, which on a tensor raises
# "Tensor.__contains__ only supports Tensor or scalar". It should be wrapped
# as {"face": tensor}.
node_features = hetero_graph.nodes["face"].data["x"].float()
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    # A fresh negative graph is built each epoch for the ("face", "main", "face") relation.
    negative_graph = construct_negative_graph(hetero_graph, k, ("face", "main", "face"))
    pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ("face", "main", "face"))
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

Error trace:

RuntimeError                              Traceback (most recent call last)
Cell In[21], line 8
      6 for epoch in range(10):
      7     negative_graph = construct_negative_graph(hetero_graph, k, ("face", "main", "face"))
----> 8     pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ("face", "main", "face"))
      9     loss = compute_loss(pos_score, neg_score)
     10     opt.zero_grad()

File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[5], line 7, in Model.forward(self, g, neg_g, x, etype)
      6 def forward(self, g, neg_g, x, etype):
----> 7     h = self.sage(g, x)
      8     return self.pred(g, h, etype), self.pred(neg_g, h, etype)

File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[6], line 14, in RGCN.forward(self, graph, inputs)
     12 def forward(self, graph, inputs):
     13     # inputs are features of nodes
---> 14     h = self.conv1(graph, inputs)
     15     h = {k: F.relu(v) for k, v in h.items()}
     16     h = self.conv2(graph, h)

File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\dgl\nn\pytorch\hetero.py:208, in HeteroGraphConv.forward(self, g, inputs, mod_args, mod_kwargs)
    206 for stype, etype, dtype in g.canonical_etypes:
    207     rel_graph = g[stype, etype, dtype]
--> 208     if stype not in inputs:
    209         continue
    210     dstdata = self._get_module((stype, etype, dtype))(
    211         rel_graph,
    212         (inputs[stype], inputs[dtype]),
    213         *mod_args.get(etype, ()),
    214         **mod_kwargs.get(etype, {})
    215     )

File ~\AppData\Local\miniconda3\envs\pmis\lib\site-packages\torch\_tensor.py:1061, in Tensor.__contains__(self, element)
   1055 if isinstance(
   1056     element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool)
   1057 ):
   1058     # type hint doesn't understand the __contains__ result array
   1059     return (element == self).any().item()  # type: ignore[union-attr]
-> 1061 raise RuntimeError(
   1062     f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}."
   1063 )

RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.

Update 1: here is an example of the input dgl graph:

Graph(num_nodes={'face': 32},
      num_edges={('face', 'cir', 'face'): 4, ('face', 'con', 'face'): 0, ('face', 'cyn', 'face'): 0, ('face', 'deg', 'face'): 2, ('face', 'dia', 'face'): 3, ('face', 'fit', 'face'): 0, ('face', 'main', 'face'): 66, ('face', 'pm', 'face'): 0, ('face', 'ro', 'face'): 3, ('face', 'rou', 'face'): 0, ('face', 'stra', 'face'): 0, ('face', 'sym', 'face'): 0},
      metagraph=[('face', 'face', 'cir'), ('face', 'face', 'con'), ('face', 'face', 'cyn'), ('face', 'face', 'deg'), ('face', 'face', 'dia'), ('face', 'face', 'fit'), ('face', 'face', 'main'), ('face', 'face', 'pm'), ('face', 'face', 'ro'), ('face', 'face', 'rou'), ('face', 'face', 'stra'), ('face', 'face', 'sym')])

Thank you in advance.

# HeteroGraphConv looks its inputs up by node type, so wrap the tensor in a dict.
node_features = hetero_graph.nodes["face"].data["x"].float()
node_features = {"face": node_features}

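node_features is supposed to be a dict[str, Tensor] keyed by node type: HeteroGraphConv checks `stype not in inputs`, which is what fails when a bare tensor is passed.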

Whole code I used:

import torch
import dgl
import dgl.graphbolt as gb
import numpy as np
import torch.multiprocessing as mp
import dgl.nn as dglnn
import dgl.function as fn
import torch.nn.functional as F
from dgl.nn import GraphConv, HeteroGraphConv


class HeteroDotProductPredictor(torch.nn.Module):
    """Score every edge of one relation by the dot product of its endpoints."""

    def forward(self, graph, h, etype):
        """Return the dot-product score of each ``etype`` edge in ``graph``.

        Args:
            graph: Graph whose ``etype`` edges are scored.
            h: Dict mapping node-type name to that type's feature tensor.
            etype: Edge type (name or canonical triple) to score.

        Returns:
            The ``'score'`` edge data written by ``apply_edges`` for ``etype``.
        """
        # local_scope keeps the temporary 'h' / 'score' fields off the caller's graph.
        with graph.local_scope():
            # Assign per node type rather than hard-coding "face"; this works
            # whether the graph has one node type or several.
            for ntype, feat in h.items():
                graph.nodes[ntype].data['h'] = feat
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

class Model(torch.nn.Module):
    """Link-prediction model: an RGCN encoder followed by a dot-product scorer."""

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        """Create the encoder (``self.sage``) and the edge predictor (``self.pred``)."""
        super().__init__()
        encoder = RGCN(in_features, hidden_features, out_features, rel_names)
        scorer = HeteroDotProductPredictor()
        self.sage = encoder
        self.pred = scorer

    def forward(self, g, neg_g, x, etype):
        """Return ``(scores on g, scores on neg_g)`` for edges of ``etype``."""
        # Encode once; both graphs are scored with the same representations.
        node_repr = self.sage(g, x)
        positive = self.pred(g, node_repr, etype)
        negative = self.pred(neg_g, node_repr, etype)
        return positive, negative

class RGCN(torch.nn.Module):
    """Two stacked HeteroGraphConv layers with one GraphConv per relation."""

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        """Build ``conv1`` (in -> hidden) and ``conv2`` (hidden -> out)."""
        super().__init__()

        def per_relation(n_in, n_out):
            # One independent GraphConv for every relation name.
            return {rel: GraphConv(n_in, n_out) for rel in rel_names}

        self.conv1 = HeteroGraphConv(
            per_relation(in_feats, hid_feats), aggregate='sum')
        self.conv2 = HeteroGraphConv(
            per_relation(hid_feats, out_feats), aggregate='sum')

    def forward(self, graph, inputs):
        """Run conv1, apply ReLU to each entry of its output, then run conv2."""
        hidden = self.conv1(graph, inputs)
        activated = {}
        for ntype, feat in hidden.items():
            activated[ntype] = F.relu(feat)
        return self.conv2(graph, activated)

def compute_loss(pos_score, neg_score):
    """Return the mean hinge (margin-1) loss for link prediction.

    Args:
        pos_score: Scores of the positive edges; the first dimension is the
            number of positive edges. A 1-D tensor is accepted as well.
        neg_score: Scores of the negative edges, laid out so that the
            negatives of each positive edge are contiguous.

    Returns:
        A scalar tensor. When there are no positive edges, a zero that is
        still connected to the autograd graph, so ``backward()`` keeps working.
    """
    n_edges = pos_score.shape[0]
    if n_edges == 0:
        # view/reshape(0, -1) is ambiguous and raises; an edge type with no
        # edges contributes no loss.
        return (pos_score.sum() + neg_score.sum()) * 0
    # One column per positive edge, so it broadcasts against its own row of negatives.
    pos = pos_score.reshape(n_edges, -1)
    neg = neg_score.reshape(n_edges, -1)
    # Margin loss
    return (1 - pos + neg).clamp(min=0).mean()

def construct_negative_graph(graph, k, etype):
    """Build a graph of ``k`` random negative edges per ``etype`` edge.

    Each source node of an existing ``etype`` edge is repeated ``k`` times
    and paired with a uniformly sampled destination node id.

    Args:
        graph: Graph to sample negatives for.
        k: Number of negative edges per existing edge.
        etype: Canonical edge type ``(src_type, relation, dst_type)``.

    Returns:
        A heterograph holding only the sampled ``etype`` edges, with the same
        number of nodes per node type as ``graph``.
    """
    _, _, vtype = etype
    src, _ = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    # Sample with the dtype and device of the existing ids so the new edge
    # lists match the source ids on GPU graphs and int32 graphs as well.
    neg_dst = torch.randint(
        0, graph.num_nodes(vtype), (len(src) * k,),
        dtype=src.dtype, device=src.device)
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

# Toy graph standing in for the custom dataset (originally a graph taken
# from the processed training split).
target_etype = ("face", "main", "face")
edge_lists = {
    target_etype: ([0, 1, 2, 3], [1, 2, 3, 0]),
    ("face", "main2", "face"): ([0, 1, 2, 2], [1, 2, 3, 0]),
    ("face", "main3", "face"): ([0, 1, 2, 3], [1, 2, 3, 2]),
}
hetero_graph = dgl.heterograph(edge_lists)
hetero_graph.nodes["face"].data["x"] = torch.randn(4, 3)
k = 5
in_dim = hetero_graph.ndata["x"].shape[1]
hidden_dim = hetero_graph.num_nodes("face")
model = Model(in_dim, hidden_dim, 1, hetero_graph.etypes)
# The encoder's inputs are keyed by node type, hence the dict.
node_features = {"face": hetero_graph.nodes["face"].data["x"].float()}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    # Resample the negative edges every epoch.
    negative_graph = construct_negative_graph(hetero_graph, k, target_etype)
    pos_score, neg_score = model(
        hetero_graph, negative_graph, node_features, target_etype)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

Hi, thank you, it works for me. I have a follow-up question: I swapped RGCN for GATv2Conv, but it does not work as expected with multiple attention heads. I used your full code and replaced the RGCN class with GNN:

class GNN(torch.nn.Module):
    """Two HeteroGraphConv layers with one GATv2Conv per relation."""

    def __init__(self, in_feats, hid_feats, out_feats, rel_names, num_heads=4):
        """Build the two attention layers.

        Args:
            in_feats: Input feature size.
            hid_feats: Per-head hidden feature size of the first layer.
            out_feats: Output feature size of the second layer.
            rel_names: Relation names; each gets its own GATv2Conv.
            num_heads: Attention heads in the first layer (default 4).
        """
        super().__init__()
        self.conv1 = HeteroGraphConv({
            rel: GATv2Conv(in_feats, hid_feats, num_heads=num_heads)
            for rel in rel_names}, aggregate='mean')
        # conv2 consumes the concatenation of all conv1 heads.
        self.conv2 = HeteroGraphConv({
            rel: GATv2Conv(hid_feats * num_heads, out_feats, num_heads=1)
            for rel in rel_names}, aggregate='mean')

    def forward(self, graph, inputs):
        """Return a dict of ``[num_nodes, out_feats]`` tensors keyed by node type.

        ``inputs`` is a dict of node features keyed by node type.
        """
        h = self.conv1(graph, inputs)
        # GATv2Conv returns [num_nodes, num_heads, hid_feats]. Concatenate the
        # heads into [num_nodes, num_heads * hid_feats], otherwise conv2's
        # linear layers see a last dimension of hid_feats and fail to multiply.
        h = {k: F.relu(v.flatten(1)) for k, v in h.items()}
        h = self.conv2(graph, h)
        # Average over the head dimension so downstream code receives plain
        # [num_nodes, out_feats] embeddings.
        return {k: v.mean(1) for k, v in h.items()}

The line print(h["face"].shape) gives me [32, 4, 20], which is [num_nodes, att_heads, out_dim]. But the next conv layer raises the error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x20 and 80x1)

It seems to me that the h output from conv1 was concatenated into shape [32*4, 20]. If I reshape it with h["face"] = h["face"].reshape(-1, att_head*hid_feats) (here [32, 80]), then the code runs fine. However, I don't think this is normal behavior, right? Could you help me find the issue?

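Please refer to https://github.com/dmlc/dgl/tree/master/examples/pytorch/gatv2 for debugging. In that example, the per-head outputs of each hidden GATv2Conv layer are merged with .flatten(1) before being passed to the next layer, and the heads of the final layer are averaged with .mean(1).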