This is what my graph looks like:
Graph(num_nodes={‘customer’: 8813, ‘product’: 157466},
num_edges={(‘customer’, ‘browsed’, ‘product’): 860771, (‘customer’, ‘purchased’, ‘product’): 68367},
metagraph=[(‘customer’, ‘product’, ‘browsed’), (‘customer’, ‘product’, ‘purchased’)])
The customer nodes have 932 features, and the product nodes, 5641 features.
My code looks like this:
class RGCN(nn.Module):
def __init__(self, in_feats_dict, hid_feats_customer, hid_feats_product, out_feats, rel_names):
super().__init__()
# Create a dictionary to store the RGCN layers for each node type
self.conv1_dict = nn.ModuleDict({
node_type: dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(in_feats_dict[node_type].shape[1],
hid_feats_customer if node_type == 'customer' else hid_feats_product).double()
for rel in rel_names
}, aggregate='sum')
for node_type in in_feats_dict
})
self.conv2_dict = nn.ModuleDict({
node_type: dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(hid_feats_customer if node_type == 'customer' else hid_feats_product, out_feats)
for rel in rel_names
}, aggregate='sum')
for node_type in in_feats_dict
})
def forward(self, graph, inputs):
# inputs are features of nodes
h = {node_type: self.conv1_dict[node_type](graph, inputs)
for node_type in inputs.keys()}
h = {k: F.relu(v) for k, v in h.items()}
h = {node_type: self.conv2_dict[node_type](graph, h[node_type])
for node_type in h.keys()}
return h
class HeteroDotProductPredictor(nn.Module):
def forward(self, graph, h, etype):
# h contains the node representations for each node type computed from
# the GNN defined in the previous section (Section 5.1).
with graph.local_scope():
graph.ndata[‘h’] = h
graph.apply_edges(fn.u_dot_v(‘h’, ‘h’, ‘score’), etype=etype)
return graph.edges[etype].data[‘score’]
def construct_negative_graph(graph, k, etype):
utype, _, vtype = etype
src, dst = graph.edges(etype=etype)
neg_src = src.repeat_interleave(k)
neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
return dgl.heterograph(
{etype: (neg_src, neg_dst)},
num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})
class Model(nn.Module):
def init(self, in_feats_dict, hid_feats_customer, hid_feats_product, out_feats, rel_names):
super().init()
self.sage = RGCN(in_feats_dict, hid_feats_customer, hid_feats_product, out_feats, rel_names)
self.pred = HeteroDotProductPredictor()
def forward(self, g, neg_g, inputs, etype):
h_customer = self.sage(g, inputs)
h_product = self.sage(g, inputs)
return self.pred(g, h_customer, etype), self.pred(neg_g, h_customer, etype)
def compute_loss(pos_score, neg_score):
# Margin loss
n_edges = pos_score.shape[0]
return (1 - pos_score + neg_score.view(n_edges, -1)).clamp(min=0).mean()
cust_feats = cust_prod_graph.nodes[‘customer’].data[‘cust_features’]
prod_feats = cust_prod_graph.nodes[‘product’].data[‘prod_features’]
node_features = {‘customer’: cust_feats, ‘product’: prod_feats}
model = Model(node_features, 20,10,16, cust_prod_graph.etypes)
opt = torch.optim.Adam(model.parameters())
negative_graph= construct_negative_graph(cust_prod_graph, 1000, (‘customer’, ‘purchased’, ‘product’))
pos_score, neg_score = model(cust_prod_graph, negative_graph,
node_features,
(‘customer’, ‘purchased’, ‘product’))
For the code above, I am getting the following errors:
RuntimeError Traceback (most recent call last)
Cell In[316], line 5
1 opt = torch.optim.Adam(model.parameters())
3 negative_graph= construct_negative_graph(cust_prod_graph, 1000, (‘customer’, ‘purchased’, ‘product’))
----> 5 pos_score, neg_score = model(cust_prod_graph, negative_graph,
6 node_features,
7 (‘customer’, ‘purchased’, ‘product’))
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don’t have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
Cell In[309], line 77, in Model.forward(self, g, neg_g, inputs, etype)
76 def forward(self, g, neg_g, inputs, etype):
—> 77 h_customer = self.sage(g, inputs)
78 h_product = self.sage(g, inputs)
79 return self.pred(g, h_customer, etype), self.pred(neg_g, h_customer, etype)
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don’t have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
Cell In[309], line 45, in RGCN.forward(self, graph, inputs)
43 def forward(self, graph, inputs):
44 # inputs are features of nodes
—> 45 h = {node_type: self.conv1_dict[node_type](graph, inputs)
46 for node_type in inputs.keys()}
47 h = {k: F.relu(v) for k, v in h.items()}
48 h = {node_type: self.conv2_dict[node_type](graph, h[node_type])
49 for node_type in h.keys()}
Cell In[309], line 45, in (.0)
43 def forward(self, graph, inputs):
44 # inputs are features of nodes
—> 45 h = {node_type: self.conv1_dict[node_type](graph, inputs)
46 for node_type in inputs.keys()}
47 h = {k: F.relu(v) for k, v in h.items()}
48 h = {node_type: self.conv2_dict[node_type](graph, h[node_type])
49 for node_type in h.keys()}
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don’t have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/lib/python3.10/site-packages/dgl/nn/pytorch/hetero.py:210, in HeteroGraphConv.forward(self, g, inputs, mod_args, mod_kwargs)
208 if stype not in inputs:
209 continue
→ 210 dstdata = self._get_module((stype, etype, dtype))(
211 rel_graph,
212 (inputs[stype], inputs[dtype]),
213 *mod_args.get(etype, ()),
214 **mod_kwargs.get(etype, {})
215 )
216 outputs[dtype].append(dstdata)
217 rsts = {}
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don’t have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/lib/python3.10/site-packages/dgl/nn/pytorch/conv/graphconv.py:450, in GraphConv.forward(self, graph, feat, weight, edge_weight)
447 if self._in_feats > self._out_feats:
448 # mult W first to reduce the feature size for aggregation.
449 if weight is not None:
→ 450 feat_src = th.matmul(feat_src, weight)
451 graph.srcdata[“h”] = feat_src
452 graph.update_all(aggregate_fn, fn.sum(msg=“m”, out=“h”))
RuntimeError: mat1 and mat2 shapes cannot be multiplied (8813x932 and 5641x10)
This is just one of the problems I have encountered. It would be great if I could get some help.
Thanks!