Layers cannot find nodes in the graph

I have a question on the behaviour of blocks in dgl. I’m trying to train a link prediction model using minibatch training, but when I start the training my model cannot find certain nodes. Here’s a reproducible example:

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
from dgl import nn as dglnn
from dgl.nn import GATv2Conv
from random import choice

class HeteroDotProductPredictor(nn.Module):
    @staticmethod
    def forward(
        graph: dgl.DGLGraph, h: torch.Tensor, etype
    ) -> torch.Tensor:
        with graph.local_scope():
            graph.ndata["h"] = h
            graph.apply_edges(fn.u_dot_v("h", "h", "score"), etype=etype)
            return graph.edges[etype].data["score"]


class LinkPrediction(nn.Module):
    def __init__(
        self, 
        input_nodes_features, 
        hidden_features, 
        out_feats, 
        graph_edges, 
        edges_to_pred
    ):
        super().__init__()
        self.graph_edges = graph_edges
        self.edges_to_pred = edges_to_pred
        self.conv1 = dglnn.HeteroGraphConv(
            {
                (source, edge, destination): dglnn.GraphConv(
                    in_feats=input_nodes_features[source], 
                    out_feats=hidden_features,
                    activation=nn.ReLU(),
                )
                for (source, edge, destination) in graph_edges
            },
        )
        self.conv1 = dglnn.HeteroGraphConv(
            {
                (source, edge, destination): dglnn.GraphConv(
                    in_feats=hidden_features,
                    out_feats=out_feats,
                    activation=nn.ReLU(),
                )
                for (source, edge, destination) in graph_edges
            },
        )
        
        self.pred = HeteroDotProductPredictor()
        
    def forward(self, positive_graph, negative_graph, blocks, x):
        h = self.conv1(blocks[0], x)
        h = self.conv2(blocks[1], h)
        pos_scores = pos = torch.cat(
            [self.pred(positive_graph, h, etype) for etype in self.edges_to_pred], 
            0
        )
        neg_scores = pos = torch.cat(
            [self.pred(negative_graph, h, etype) for etype in self.edges_to_pred], 
            0
        )


def create_dataloader(graph, indices, sampler, batch_size: int = 1024, device="cpu"):
    return dgl.dataloading.DataLoader(
        # The following arguments are specific to DataLoader.
        graph,  # The graph
        indices,  # The edges to iterate over
        sampler,  # The neighbor sampler
        device=device,  # Put the MFGs on CPU or GPU
        use_uva=False,  # use UVA to speed loading
        # The following arguments are inherited from PyTorch DataLoader.
        batch_size=batch_size,  # Batch size
        shuffle=True,  # Whether to shuffle the nodes for every epoch
        drop_last=False,  # Whether to drop the last incomplete batch
        num_workers=0,  # Number of sampler processes
    )


def compute_hinge_loss(pos_score, neg_score, margin=1, **kwargs):
    # an example hinge loss
    n = pos_score.shape[0]
    return (neg_score.view(n, -1) - pos_score.view(n, -1) + margin).clamp(min=0).mean()


def generate_features_dict(graph):
    return {
        "user": graph.nodes["user"].data["features"],
        "event": graph.nodes["event"].data["features"],
    }


def compute_average_precision(
    pos_score: torch.Tensor, neg_score: torch.Tensor, **kwargs
):
    scores = torch.cat([pos_score, neg_score]).squeeze().detach()

    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    return average_precision_score(labels.cpu(), scores.cpu())

# Define the number of nodes for each type
num_users = 15
num_events = 200

# Create a heterograph
graph = dgl.heterograph(
    {
        ("user", "friend", "user"): (
            [choice(list(range(num_users))) for _ in range(20)],
            [choice(list(range(num_users))) for _ in range(20)],
        ),
        ("user", "participate", "event"): (
            [choice(list(range(num_users))) for _ in range(10)],
            [choice(list(range(num_events))) for _ in range(10)],
        ),
        ("user", "like", "event"): (
            [choice(list(range(num_users))) for _ in range(500)],
            [choice(list(range(num_events))) for _ in range(500)],
        ),
        ("user", "follow", "user"): (
            [choice(list(range(num_users))) for _ in range(15)],
            [choice(list(range(num_users))) for _ in range(15)],
        ),
    }
)

graph.nodes["user"].data["features"] = torch.randint(18, 35, (num_users, 10))
graph.nodes["event"].data["features"] = torch.ones(num_events, 100)

sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler=dgl.dataloading.MultiLayerFullNeighborSampler(2),
    negative_sampler=dgl.dataloading.negative_sampler.Uniform(3),
)

edges_dict = {
    ("user", "friend", "user"): graph.edges(
        etype=("user", "friend", "user"), form="eid"
    ),
    ("user", "participate", "event"): graph.edges(
        etype=("user", "participate", "event"), form="eid"
    ),
    ("user", "like", "event"): graph.edges(
        etype=("user", "like", "event"), form="eid"
    ),
    ("user", "follow", "user"): graph.edges(
        etype=("user", "follow", "user"), form="eid"
    ),
}

dataloader = create_dataloader(graph, edges_dict, sampler, batch_size=4096)

input_features_dims = {
    "user": graph.nodes["user"].data["features"].shape[1],
    "event": graph.nodes["event"].data["features"].shape[1],
}

graph_edges = [
    ("user", "user", "follow"), 
    ("user", "user", "friend"), 
    ("user", "event", "like"), 
    ("user", "event", "participate")
]

model = LinkPrediction(
    input_nodes_features=input_features_dims,
    hidden_features=8,
    out_feats=4,
    graph_edges=graph_edges,
    edges_to_pred=["user", "partecipate", "event"],
)

for epoch in range(30):
    model.train()
    pos_scores = []
    neg_scores = []
    losses = []
    print(epoch)
    for input_nodes, pos_graph, neg_graph, blocks in tqdm(dataloader):
        input_features = generate_features_dict(blocks[0])
        pos_score, neg_score = model(
            positive_graph=pos_graph,
            negative_graph=neg_graph,
            blocks=blocks,
            x=input_features,
        )
        loss = compute_hinge_loss(pos_score, neg_score)
        pos_scores.extend(pos_score.squeeze().tolist())
        neg_scores.extend(neg_score.squeeze().tolist())

        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    metric = compute_average_precision(
        pos_score=torch.tensor(pos_scores), neg_score=torch.tensor(neg_scores)
    )
    print(sum(losses) / len(losses), metric)

after running this code I get the following error:

KeyError                                  Traceback (most recent call last)
Cell In[12], line 9
      7 for input_nodes, pos_graph, neg_graph, blocks in tqdm(dataloader):
      8     input_features = generate_features_dict(blocks[0])
----> 9     pos_score, neg_score = model(
     10         positive_graph=pos_graph,
     11         negative_graph=neg_graph,
     12         blocks=blocks,
     13         x=input_features,
     14     )
     15     loss = compute_hinge_loss(pos_score, neg_score)
     16     pos_scores.extend(pos_score.squeeze().tolist())

File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[5], line 37
     36 def forward(self, positive_graph, negative_graph, blocks, x):
---> 37     h = self.conv1(blocks[0], x)
     38     h = self.conv2(blocks[1], h)
     39     pos_scores = pos = torch.cat(
     40         [self.pred(positive_graph, h, etype) for etype in self.edges_to_pred], 
     41         0
     42     )

File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/dgl/nn/pytorch/hetero.py:198, in HeteroGraphConv.forward(self, g, inputs, mod_args, mod_kwargs)
    196         if stype not in src_inputs or dtype not in dst_inputs:
    197             continue
--> 198         dstdata = self._get_module((stype, etype, dtype))(
    199             rel_graph,
    200             (src_inputs[stype], dst_inputs[dtype]),
    201             *mod_args.get(etype, ()),
    202             **mod_kwargs.get(etype, {})
    203         )
    204         outputs[dtype].append(dstdata)
    205 else:

File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/dgl/nn/pytorch/hetero.py:156, in HeteroGraphConv._get_module(self, etype)
    153 if isinstance(etype, tuple):
    154     # etype is canonical
    155     _, etype, _ = etype
--> 156     return self.mod_dict[etype]
    157 raise KeyError("Cannot find module with edge type %s" % etype)

KeyError: 'follow'

It seems like the first convolution cannot find the feature relative to the “follow” node. How is it possible?

This part should be of order (source_type, edge_type, destination_type), in other words,

graph_edges = [
    ("user", "follow", "user"), 
    ("user", "friend", "user"), 
    ("user", "like", "event"), 
    ("user", "participate", "event")
]
1 Like