KeyError: 'x' in GNN for Link Prediction with Neighborhood Sampling

Hi.

I’m following this tutorial: 6.3 Training GNN for Link Prediction with Neighborhood Sampling — DGL 0.6.1 documentation

I have a heterograph with the following structure:

Graph(num_nodes={'post': 11911, 'user': 9482},
      num_edges={('user', 'like', 'post'): 363984},
      metagraph=[('user', 'post', 'like')])

When running the training loop, I get this error:

KeyError: 'x'

My intuition is that it happens because I have just one edge type; the same architecture worked with another graph that had two edge types.

What should I change? For now I fixed it by adding reverse edges, but I assume there is a proper way to do it.

Thanks!!

Predictor and model:

class ScorePredictor(nn.Module):
    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['x'] = x
            # Score each sampled edge as the dot product of its
            # endpoint representations
            for etype in edge_subgraph.canonical_etypes:
                edge_subgraph.apply_edges(
                    dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
            return edge_subgraph.edata['score']

class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_classes,
                 etypes):
        super().__init__()
        # num_classes is unused in this link-prediction model
        self.rgcn = StochasticTwoLayerRGCN(
            in_features, hidden_features, out_features, etypes)
        self.pred = ScorePredictor()

    def forward(self, positive_graph, negative_graph, blocks, x):
        # Compute updated node representations, then score the sampled
        # positive and negative edges
        x = self.rgcn(blocks, x)
        pos_score = self.pred(positive_graph, x)
        neg_score = self.pred(negative_graph, x)
        return pos_score, neg_score
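For reference, StochasticTwoLayerRGCN here is the module from the linked tutorial, roughly:

import dgl.nn as dglnn
import torch.nn as nn

class StochasticTwoLayerRGCN(nn.Module):
    def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
        super().__init__()
        # One GraphConv per edge type, combined by HeteroGraphConv
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feat, hidden_feat, norm='right')
            for rel in rel_names})
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hidden_feat, out_feat, norm='right')
            for rel in rel_names})

    def forward(self, blocks, x):
        # One message-passing block per layer
        x = self.conv1(blocks[0], x)
        x = self.conv2(blocks[1], x)
        return x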

Training loop:

for epoch in range(n_epochs):
    for input_nodes, positive_graph, negative_graph, blocks in dataloader:
        # (if training on GPU, move the blocks and graphs to the device here)
        input_features = blocks[0].srcdata['feats']
        pos_score, neg_score = model(positive_graph, negative_graph, blocks,
                                     input_features)
        loss = compute_loss(pos_score, neg_score, g.canonical_etypes)
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(loss.item())
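compute_loss is not shown above; a minimal margin-loss sketch in the tutorial’s style, assuming pos_score and neg_score are dicts keyed by canonical edge type and that the negative sampler draws a fixed number of negatives per positive edge:

import torch

def compute_loss(pos_score, neg_score, canonical_etypes):
    # Margin loss averaged over all edge types; each positive edge
    # is compared against its sampled negatives
    losses = []
    for etype in canonical_etypes:
        n = pos_score[etype].shape[0]
        losses.append(
            (neg_score[etype].view(n, -1) - pos_score[etype].view(n, -1) + 1)
            .clamp(min=0).mean())
    return torch.stack(losses).mean()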

Could you please post the entire error stack?

This happens for the reason @mufeili stated in “In the RGCN forward function. when the second graphConv function is applied, the 'h' suddenly becomes a zero tensor. Not sure why it becomes like this - #3 by fajim123”: “After the first convolution layer, you have updated node representations for item-typed nodes only since you only have edges from user-typed nodes to item-typed nodes. Now for the second convolution layer, since you do not provide node representations for the user-typed nodes, the HeteroGraphConv object skips the computation and returns an empty dictionary.”
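A minimal sketch of the reverse-edge fix you described, with toy ID tensors standing in for the real edges (the ('post', 'liked', 'user') name is just one choice):

import dgl
import torch

# toy endpoint IDs standing in for the real ('user', 'like', 'post') edges
u = torch.tensor([0, 1, 2])   # user IDs
v = torch.tensor([5, 5, 9])   # post IDs

g = dgl.heterograph({
    ('user', 'like', 'post'): (u, v),
    # reverse edge type so user-typed nodes also receive messages
    # in the first convolution layer
    ('post', 'liked', 'user'): (v, u),
})

With the reverse type present, both node types receive messages in the first layer, so the second HeteroGraphConv layer has input representations for every node type.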


Hi. Another question: since I have created the reverse edges, do I need to train the model on both link directions, the original edge and its reverse, or is training on just one direction enough?

Do you mean reverse edges or both reverse edges and reverse edge types? What are the edge types for which you want to perform link prediction?

I had an original edge type, (‘user’, ‘like’, ‘post’), and I added (‘post’, ‘liked’, ‘user’) to solve the previously mentioned problem. I want to predict new edges between users and posts. For inference I understand it is basically the same, but I am not sure whether I should train the model on both edge types or just on one. Thanks!

This might depend on how you performed link prediction with updated node representations. If you take the dot product of pairs of node representations, then it makes no difference whether you include these reverse edges.
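Concretely, a dot-product score is symmetric in the two endpoints, so a user–post pair gets the same score regardless of which direction you score. A toy check (random features, hypothetical sizes):

import torch

h_user = torch.randn(9482, 16)   # updated user representations
h_post = torch.randn(11911, 16)  # updated post representations

u, v = 3, 7  # some user/post pair
score_like = h_user[u].dot(h_post[v])   # ('user', 'like', 'post') direction
score_liked = h_post[v].dot(h_user[u])  # ('post', 'liked', 'user') direction
assert torch.allclose(score_like, score_liked)  # identical by symmetry

So training on one direction suffices when scoring is a plain dot product.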


One last question.

My initial embeddings are BERT-extracted for ‘post’ nodes, but are initialized randomly for ‘user’ nodes. Since I will train on just one edge type, would it be better to train the model on the (‘user’, ‘like’, ‘post’) edge type or on the (‘post’, ‘liked’, ‘user’) one, or does it make no difference?

Thanks.

It doesn’t matter since you will use the same graph for message passing.

