I am trying to build a GNN model on a heterogeneous graph for link prediction. The chapters on the DGL website do not seem to cover this case in enough depth, and I am stuck.

This is what my graph looks like:
Graph(num_nodes={'customer': 8813, 'product': 157466},
      num_edges={('customer', 'browsed', 'product'): 860771, ('customer', 'purchased', 'product'): 68367},
      metagraph=[('customer', 'product', 'browsed'), ('customer', 'product', 'purchased')])

The customer nodes have 932 features, and the product nodes, 5641 features.
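For context, the graph was built along these lines (the edge index and feature tensors below are just placeholders with the right shapes; the real ones come from my data):

import dgl
import torch

# Placeholder edge lists with the same sizes as the real browse/purchase data
browsed = (torch.randint(0, 8813, (860771,)), torch.randint(0, 157466, (860771,)))
purchased = (torch.randint(0, 8813, (68367,)), torch.randint(0, 157466, (68367,)))

cust_prod_graph = dgl.heterograph({
    ('customer', 'browsed', 'product'): browsed,
    ('customer', 'purchased', 'product'): purchased,
}, num_nodes_dict={'customer': 8813, 'product': 157466})

# Placeholder node features with the dimensions mentioned above
cust_prod_graph.nodes['customer'].data['cust_features'] = torch.randn(8813, 932)
cust_prod_graph.nodes['product'].data['prod_features'] = torch.randn(157466, 5641)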

My code looks like this:

import dgl
import dgl.function as fn
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F


class RGCN(nn.Module):
    def __init__(self, in_feats_dict, hid_feats_customer, hid_feats_product, out_feats, rel_names):
        super().__init__()

        # Create a dictionary to store the RGCN layers for each node type
        self.conv1_dict = nn.ModuleDict({
            node_type: dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(in_feats_dict[node_type].shape[1],
                                     hid_feats_customer if node_type == 'customer' else hid_feats_product).double()
                for rel in rel_names
            }, aggregate='sum')
            for node_type in in_feats_dict
        })

        self.conv2_dict = nn.ModuleDict({
            node_type: dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(hid_feats_customer if node_type == 'customer' else hid_feats_product, out_feats)
                for rel in rel_names
            }, aggregate='sum')
            for node_type in in_feats_dict
        })

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = {node_type: self.conv1_dict[node_type](graph, inputs)
             for node_type in inputs.keys()}
        h = {k: F.relu(v) for k, v in h.items()}
        h = {node_type: self.conv2_dict[node_type](graph, h[node_type])
             for node_type in h.keys()}
        return h

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each node type computed from
        # the GNN defined in the previous section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

def construct_negative_graph(graph, k, etype):
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

class Model(nn.Module):
    def __init__(self, in_feats_dict, hid_feats_customer, hid_feats_product, out_feats, rel_names):
        super().__init__()
        self.sage = RGCN(in_feats_dict, hid_feats_customer, hid_feats_product, out_feats, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, inputs, etype):
        h_customer = self.sage(g, inputs)
        h_product = self.sage(g, inputs)
        return self.pred(g, h_customer, etype), self.pred(neg_g, h_customer, etype)

def compute_loss(pos_score, neg_score):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()

cust_feats = cust_prod_graph.nodes['customer'].data['cust_features']
prod_feats = cust_prod_graph.nodes['product'].data['prod_features']
node_features = {'customer': cust_feats, 'product': prod_feats}

model = Model(node_features, 20, 10, 16, cust_prod_graph.etypes)

opt = torch.optim.Adam(model.parameters())

negative_graph = construct_negative_graph(cust_prod_graph, 1000, ('customer', 'purchased', 'product'))

pos_score, neg_score = model(cust_prod_graph, negative_graph,
                             node_features,
                             ('customer', 'purchased', 'product'))

For the code above, I am getting the following error:

RuntimeError Traceback (most recent call last)
Cell In[316], line 5
1 opt = torch.optim.Adam(model.parameters())
3 negative_graph = construct_negative_graph(cust_prod_graph, 1000, ('customer', 'purchased', 'product'))
----> 5 pos_score, neg_score = model(cust_prod_graph, negative_graph,
6 node_features,
7 ('customer', 'purchased', 'product'))

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[309], line 77, in Model.forward(self, g, neg_g, inputs, etype)
76 def forward(self, g, neg_g, inputs, etype):
---> 77 h_customer = self.sage(g, inputs)
78 h_product = self.sage(g, inputs)
79 return self.pred(g, h_customer, etype), self.pred(neg_g, h_customer, etype)

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[309], line 45, in RGCN.forward(self, graph, inputs)
43 def forward(self, graph, inputs):
44 # inputs are features of nodes
---> 45 h = {node_type: self.conv1_dict[node_type](graph, inputs)
46 for node_type in inputs.keys()}
47 h = {k: F.relu(v) for k, v in h.items()}
48 h = {node_type: self.conv2_dict[node_type](graph, h[node_type])
49 for node_type in h.keys()}

Cell In[309], line 45, in <dictcomp>(.0)
43 def forward(self, graph, inputs):
44 # inputs are features of nodes
---> 45 h = {node_type: self.conv1_dict[node_type](graph, inputs)
46 for node_type in inputs.keys()}
47 h = {k: F.relu(v) for k, v in h.items()}
48 h = {node_type: self.conv2_dict[node_type](graph, h[node_type])
49 for node_type in h.keys()}

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/lib/python3.10/site-packages/dgl/nn/pytorch/hetero.py:210, in HeteroGraphConv.forward(self, g, inputs, mod_args, mod_kwargs)
208 if stype not in inputs:
209 continue
--> 210 dstdata = self._get_module((stype, etype, dtype))(
211 rel_graph,
212 (inputs[stype], inputs[dtype]),
213 *mod_args.get(etype, ()),
214 **mod_kwargs.get(etype, {})
215 )
216 outputs[dtype].append(dstdata)
217 rsts = {}

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/lib/python3.10/site-packages/dgl/nn/pytorch/conv/graphconv.py:450, in GraphConv.forward(self, graph, feat, weight, edge_weight)
447 if self._in_feats > self._out_feats:
448 # mult W first to reduce the feature size for aggregation.
449 if weight is not None:
--> 450 feat_src = th.matmul(feat_src, weight)
451 graph.srcdata["h"] = feat_src
452 graph.update_all(aggregate_fn, fn.sum(msg="m", out="h"))

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8813x932 and 5641x10)

This is just one of the problems I have encountered. It would be great if I could get some help.

Thanks!

Hello, I noticed there seems to be some confusion about heterogeneous graphs in your code, so I've made some adjustments to the RGCN model. In a heterograph you don't need to construct a separate conv for each node type: one HeteroGraphConv per layer already holds a GraphConv for every relation and aggregates the messages for each destination node type.

class RGCN(nn.Module):
    def __init__(self, in_feats_dict, hid_feats_customer, hid_feats_product, out_feats, rel_names):
        super().__init__()

        # One HeteroGraphConv per layer, holding a GraphConv for each relation
        #! Modify
        self.conv1_dict = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats_dict[src_node_type].shape[1], 
            hid_feats_customer if src_node_type == 'customer' else hid_feats_product)
            for src_node_type, rel, _ in rel_names
        }, aggregate='sum')

        self.conv2_dict = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats_customer if src_node_type == 'customer' else hid_feats_product, out_feats)
            for src_node_type, rel, _ in rel_names
        }, aggregate='sum')

        
    def forward(self, graph, inputs):
        # inputs are features of nodes
        #! Modify
        h = self.conv1_dict(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2_dict(graph, h)
        return h

A little tip: define rel_names as rel_names = cust_prod_graph.canonical_etypes, so each entry is a (src_type, relation, dst_type) tuple that the comprehensions above can unpack.
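For example, with the sizes from your snippet (this assumes Model passes rel_names straight through to RGCN, as in your code):

rel_names = cust_prod_graph.canonical_etypes
# [('customer', 'browsed', 'product'), ('customer', 'purchased', 'product')]

model = Model(node_features, 20, 10, 16, rel_names)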


After modifying the model, please also note that there are still issues in the forward method of HeteroDotProductPredictor that you will need to work through.
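Once those are sorted out, a minimal training loop in the style of the DGL user guide would look roughly like this (the epoch count and the number of negatives per positive edge are placeholders, not tuned values):

opt = torch.optim.Adam(model.parameters())
etype = ('customer', 'purchased', 'product')
k = 5  # negatives sampled per positive edge (placeholder)

for epoch in range(10):
    negative_graph = construct_negative_graph(cust_prod_graph, k, etype)
    pos_score, neg_score = model(cust_prod_graph, negative_graph, node_features, etype)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(epoch, loss.item())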

If you have any questions, feel free to ask.
