In the RGCN forward function. when the second graphConv function is applied, the 'h' suddenly becomes a zero tensor. Not sure why it becomes like this

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-38-e89b4657d4d0> in <module>
      7 for epoch in range(10):
      8     negative_graph = construct_negative_graph(G, k, ('Customer', 'Pays', 'Merchant'))
----> 9     pos_score, neg_score = model(G, negative_graph, node_features, ('Customer', 'Pays', 'Merchant'))
     10     loss = compute_loss(pos_score, neg_score)
     11     opt.zero_grad()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-36-6aab2b2b5fac> in forward(self, g, neg_g, x, etype)
      6     def forward(self, g, neg_g, x, etype):
      7         h = self.sage(g, x)
----> 8         return self.pred(g, h, etype), self.pred(neg_g, h, etype)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-34-2ba1e2ddc2b2> in forward(self, graph, h, etype)
      5         with graph.local_scope():
      6             graph.ndata['h'] = h
----> 7             graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
      8             return graph.edges[etype].data['score']

/opt/conda/lib/python3.7/site-packages/dgl/heterograph.py in apply_edges(self, func, edges, etype, inplace)
   4303             if not is_all(eid):
   4304                 g = g.edge_subgraph(eid, preserve_nodes=True)
-> 4305             edata = core.invoke_gsddmm(g, func)
   4306         else:
   4307             edata = core.invoke_edge_udf(g, eid, etype, func)

/opt/conda/lib/python3.7/site-packages/dgl/core.py in invoke_gsddmm(graph, func)
    202     alldata = [graph.srcdata, graph.dstdata, graph.edata]
    203     if isinstance(func, fn.BinaryMessageFunction):
--> 204         x = alldata[func.lhs][func.lhs_field]
    205         y = alldata[func.rhs][func.rhs_field]
    206         op = getattr(ops, func.name)

/opt/conda/lib/python3.7/site-packages/dgl/view.py in __getitem__(self, key)
     64             return ret
     65         else:
---> 66             return self._graph._get_n_repr(self._ntid, self._nodes)[key]
     67 
     68     def __setitem__(self, key, val):

/opt/conda/lib/python3.7/site-packages/dgl/frame.py in __getitem__(self, name)
    391             Column data.
    392         """
--> 393         return self._columns[name].data
    394 
    395     def __setitem__(self, name, data):

KeyError: 'h'
1 Like

Here are the previous codes

import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h


class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each node type computed from
        # the GNN defined in the previous section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

def construct_negative_graph(graph, k, etype):
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, neg_g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)


def compute_loss(pos_score, neg_score):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (1 - pos_score.unsqueeze(1) + neg_score.view(n_edges, -1)).clamp(min=0).mean()


k = 5
model = Model(2, 20, 5, G.etypes)
merchant_feats = G.nodes['Merchant'].data['feat']
customer_feats = G.nodes['Customer'].data['feat']
node_features = {'Merchant': merchant_feats, 'Customer': customer_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    negative_graph = construct_negative_graph(G, k, ('Customer', 'Pays', 'Merchant'))
    pos_score, neg_score = model(G, negative_graph, node_features, ('Customer', 'Pays', 'Merchant'))
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

And shape of the features used

G.nodes[‘Merchant’].data[‘feat’].shape = torch.Size([5, 2])

G.nodes[‘Customer’].data[‘feat’].shape = torch.Size([10, 2])

G.edata[‘feat’].shape = torch.Size([10, 2])

You mean h becomes zero from nonzero in the line of code below?

h = self.conv2(graph, h)
1 Like

Not that code but in the forward function

Can you check when does it happen in the forward function of self.sage?

Apologies, but you’re correct. The empty tensor occurs in this line of code

Are they non-zero after you applied ReLU?

Yup they are non-zero after this line

This code gives the same error and should be reproducible from your end

import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
import dgl.function as fn

from dgl import DGLGraph

import numpy as np
import torch

n_users = 1000
n_items = 500
n_follows = 3000
n_clicks = 5000
n_hetero_features = 2



follow_src = np.random.randint(0, n_users, n_follows)
follow_dst = np.random.randint(0, n_users, n_follows)
click_src = np.random.randint(0, n_users, n_clicks)
click_dst = np.random.randint(0, n_items, n_clicks)


hetero_graph = dgl.heterograph({
    ('user', 'click', 'item'): (click_src, click_dst)
    })

hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)
hetero_graph.nodes['item'].data['feature'] = torch.randn(n_items, n_hetero_features)


class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        print(1,h)
        h = {k: F.relu(v) for k, v in h.items()}
        print(2,h)
        h = self.conv2(graph, h)
        print(3,h)
        return h

class HeteroDotProductPredictor(nn.Module):
    def forward(self, graph, h, etype):
        # h contains the node representations for each node type computed from
        # the GNN defined in the previous section (Section 5.1).
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            return graph.edges[etype].data['score']

def construct_negative_graph(graph, k, etype):
    utype, _, vtype = etype
    src, dst = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,))
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, rel_names):
        super().__init__()
        self.sage = RGCN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()
    def forward(self, g, neg_g, x, etype):
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

def compute_loss(pos_score, neg_score):
    # Margin loss
    n_edges = pos_score.shape[0]
    return (1 - pos_score.unsqueeze(1) + neg_score.view(n_edges, -1)).clamp(min=0).mean()

k = 5
model = Model(2, 20, 5, hetero_graph.etypes)
user_feats = hetero_graph.nodes['user'].data['feature']
item_feats = hetero_graph.nodes['item'].data['feature']
node_features = {'user': user_feats, 'item': item_feats}
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    negative_graph = construct_negative_graph(hetero_graph, k, ('user', 'click', 'item'))
    pos_score, neg_score = model(hetero_graph, negative_graph, node_features, ('user', 'click', 'item'))
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

Your graph has a single relation ('user', 'click', 'item'). Initially you have node features for both user-typed nodes and item-typed nodes. After the first convolution layer, you have updated node representations for item-typed nodes only since you only have edges from user-typed nodes to item-typed nodes. Now for the second convolution layer, since you do not provide node representations for the user-typed nodes, the HeteroGraphConv object skips the computation and returns an empty dictionary.

3 Likes

So If the graph is bi-directional or if each of the node types has an in-degree and out-degree then there will not be an empty dictionary?

So If the graph is bi-directional or if each of the node types has an in-degree and out-degree then there will not be an empty dictionary?

Yes.

1 Like

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.