I have a question on the behaviour of blocks in dgl. I’m trying to train a link prediction model using minibatch training, but when I start the training my model cannot find certain nodes. Here’s a reproducible example:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
from dgl import nn as dglnn
from dgl.nn import GATv2Conv
from random import choice
class HeteroDotProductPredictor(nn.Module):
@staticmethod
def forward(
graph: dgl.DGLGraph, h: torch.Tensor, etype
) -> torch.Tensor:
with graph.local_scope():
graph.ndata["h"] = h
graph.apply_edges(fn.u_dot_v("h", "h", "score"), etype=etype)
return graph.edges[etype].data["score"]
class LinkPrediction(nn.Module):
def __init__(
self,
input_nodes_features,
hidden_features,
out_feats,
graph_edges,
edges_to_pred
):
super().__init__()
self.graph_edges = graph_edges
self.edges_to_pred = edges_to_pred
self.conv1 = dglnn.HeteroGraphConv(
{
(source, edge, destination): dglnn.GraphConv(
in_feats=input_nodes_features[source],
out_feats=hidden_features,
activation=nn.ReLU(),
)
for (source, edge, destination) in graph_edges
},
)
self.conv1 = dglnn.HeteroGraphConv(
{
(source, edge, destination): dglnn.GraphConv(
in_feats=hidden_features,
out_feats=out_feats,
activation=nn.ReLU(),
)
for (source, edge, destination) in graph_edges
},
)
self.pred = HeteroDotProductPredictor()
def forward(self, positive_graph, negative_graph, blocks, x):
h = self.conv1(blocks[0], x)
h = self.conv2(blocks[1], h)
pos_scores = pos = torch.cat(
[self.pred(positive_graph, h, etype) for etype in self.edges_to_pred],
0
)
neg_scores = pos = torch.cat(
[self.pred(negative_graph, h, etype) for etype in self.edges_to_pred],
0
)
def create_dataloader(graph, indices, sampler, batch_size: int = 1024, device="cpu"):
return dgl.dataloading.DataLoader(
# The following arguments are specific to DataLoader.
graph, # The graph
indices, # The edges to iterate over
sampler, # The neighbor sampler
device=device, # Put the MFGs on CPU or GPU
use_uva=False, # use UVA to speed loading
# The following arguments are inherited from PyTorch DataLoader.
batch_size=batch_size, # Batch size
shuffle=True, # Whether to shuffle the nodes for every epoch
drop_last=False, # Whether to drop the last incomplete batch
num_workers=0, # Number of sampler processes
)
def compute_hinge_loss(pos_score, neg_score, margin=1, **kwargs):
# an example hinge loss
n = pos_score.shape[0]
return (neg_score.view(n, -1) - pos_score.view(n, -1) + margin).clamp(min=0).mean()
def generate_features_dict(graph):
return {
"user": graph.nodes["user"].data["features"],
"event": graph.nodes["event"].data["features"],
}
def compute_average_precision(
pos_score: torch.Tensor, neg_score: torch.Tensor, **kwargs
):
scores = torch.cat([pos_score, neg_score]).squeeze().detach()
labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
return average_precision_score(labels.cpu(), scores.cpu())
# Define the number of nodes for each type
num_users = 15
num_events = 200
# Create a heterograph
graph = dgl.heterograph(
{
("user", "friend", "user"): (
[choice(list(range(num_users))) for _ in range(20)],
[choice(list(range(num_users))) for _ in range(20)],
),
("user", "participate", "event"): (
[choice(list(range(num_users))) for _ in range(10)],
[choice(list(range(num_events))) for _ in range(10)],
),
("user", "like", "event"): (
[choice(list(range(num_users))) for _ in range(500)],
[choice(list(range(num_events))) for _ in range(500)],
),
("user", "follow", "user"): (
[choice(list(range(num_users))) for _ in range(15)],
[choice(list(range(num_users))) for _ in range(15)],
),
}
)
graph.nodes["user"].data["features"] = torch.randint(18, 35, (num_users, 10))
graph.nodes["event"].data["features"] = torch.ones(num_events, 100)
sampler = dgl.dataloading.as_edge_prediction_sampler(
sampler=dgl.dataloading.MultiLayerFullNeighborSampler(2),
negative_sampler=dgl.dataloading.negative_sampler.Uniform(3),
)
edges_dict = {
("user", "friend", "user"): graph.edges(
etype=("user", "friend", "user"), form="eid"
),
("user", "participate", "event"): graph.edges(
etype=("user", "participate", "event"), form="eid"
),
("user", "like", "event"): graph.edges(
etype=("user", "like", "event"), form="eid"
),
("user", "follow", "user"): graph.edges(
etype=("user", "follow", "user"), form="eid"
),
}
dataloader = create_dataloader(graph, edges_dict, sampler, batch_size=4096)
input_features_dims = {
"user": graph.nodes["user"].data["features"].shape[1],
"event": graph.nodes["event"].data["features"].shape[1],
}
graph_edges = [
("user", "user", "follow"),
("user", "user", "friend"),
("user", "event", "like"),
("user", "event", "participate")
]
model = LinkPrediction(
input_nodes_features=input_features_dims,
hidden_features=8,
out_feats=4,
graph_edges=graph_edges,
edges_to_pred=["user", "partecipate", "event"],
)
for epoch in range(30):
model.train()
pos_scores = []
neg_scores = []
losses = []
print(epoch)
for input_nodes, pos_graph, neg_graph, blocks in tqdm(dataloader):
input_features = generate_features_dict(blocks[0])
pos_score, neg_score = model(
positive_graph=pos_graph,
negative_graph=neg_graph,
blocks=blocks,
x=input_features,
)
loss = compute_hinge_loss(pos_score, neg_score)
pos_scores.extend(pos_score.squeeze().tolist())
neg_scores.extend(neg_score.squeeze().tolist())
opt.zero_grad()
loss.backward()
opt.step()
losses.append(loss.item())
metric = compute_average_precision(
pos_score=torch.tensor(pos_scores), neg_score=torch.tensor(neg_scores)
)
print(sum(losses) / len(losses), metric)
after running this code I get the following error:
KeyError Traceback (most recent call last)
Cell In[12], line 9
7 for input_nodes, pos_graph, neg_graph, blocks in tqdm(dataloader):
8 input_features = generate_features_dict(blocks[0])
----> 9 pos_score, neg_score = model(
10 positive_graph=pos_graph,
11 negative_graph=neg_graph,
12 blocks=blocks,
13 x=input_features,
14 )
15 loss = compute_hinge_loss(pos_score, neg_score)
16 pos_scores.extend(pos_score.squeeze().tolist())
File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
Cell In[5], line 37
36 def forward(self, positive_graph, negative_graph, blocks, x):
---> 37 h = self.conv1(blocks[0], x)
38 h = self.conv2(blocks[1], h)
39 pos_scores = pos = torch.cat(
40 [self.pred(positive_graph, h, etype) for etype in self.edges_to_pred],
41 0
42 )
File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/dgl/nn/pytorch/hetero.py:198, in HeteroGraphConv.forward(self, g, inputs, mod_args, mod_kwargs)
196 if stype not in src_inputs or dtype not in dst_inputs:
197 continue
--> 198 dstdata = self._get_module((stype, etype, dtype))(
199 rel_graph,
200 (src_inputs[stype], dst_inputs[dtype]),
201 *mod_args.get(etype, ()),
202 **mod_kwargs.get(etype, {})
203 )
204 outputs[dtype].append(dstdata)
205 else:
File ~/.local/share/virtualenvs/test_creazione_grafo-lMMVGt1W/lib/python3.11/site-packages/dgl/nn/pytorch/hetero.py:156, in HeteroGraphConv._get_module(self, etype)
153 if isinstance(etype, tuple):
154 # etype is canonical
155 _, etype, _ = etype
--> 156 return self.mod_dict[etype]
157 raise KeyError("Cannot find module with edge type %s" % etype)
KeyError: 'follow'
It seems like the first convolution cannot find the feature relative to the “follow” node. How is it possible?