Batched negative sampling with GraphDataLoader for link prediction

Hi, I have some questions about the link prediction task, following up on the discussion here. I successfully trained on a single graph by following the tutorial. Now I want to train on batches of graphs, where negative sampling has to be done per batch:

import torch
import dgl
import dgl.function as fn
import torch.nn.functional as F
from dgl.nn.pytorch import HeteroGraphConv
from dgl.nn.pytorch.conv import GraphConv

class HeteroDotProductPredictor(torch.nn.Module):
    """Score edges of a single edge type by the dot product of endpoint embeddings.

    Only the embeddings of node type ``"face"`` are used, so this predictor
    assumes a graph whose nodes are all of that one type.
    """

    def forward(self, graph, h, etype):
        """Return one score per edge of ``etype`` in ``graph``.

        Args:
            graph: the (possibly negative) graph whose edges are scored.
            h: dict mapping node type to node embeddings; ``h['face']`` is read.
            etype: the edge type whose edges are scored.

        Returns:
            The ``'score'`` edge feature computed by ``u_dot_v``.
        """
        face_embeddings = h['face']
        # local_scope drops the temporary 'h' / 'score' features on exit,
        # leaving the caller's graph untouched.
        with graph.local_scope():
            graph.ndata['h'] = face_embeddings
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            edge_scores = graph.edges[etype].data['score']
        return edge_scores

class Model(torch.nn.Module):
    """Link-prediction model: a heterogeneous GNN encoder plus a dot-product scorer.

    Args:
        in_features: size of the input node features.
        hidden_features: size of the hidden layer.
        out_features: size of the output node embeddings.
        rel_names: edge-type names, one ``GraphConv`` per relation.
    """

    def __init__(self, in_features, hidden_features, out_features, rel_names):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, ``self.sage = ...`` raises AttributeError.
        super().__init__()
        self.sage = GNN(in_features, hidden_features, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        """Return ``(positive_scores, negative_scores)`` for edges of ``etype``.

        Node embeddings are computed once on the positive graph ``g`` and reused
        to score the edges of the negative graph ``neg_g``. This relies on both
        graphs sharing the same node set.
        """
        h = self.sage(g, x)
        return self.pred(g, h, etype), self.pred(neg_g, h, etype)

class GNN(torch.nn.Module):
    """Two-layer heterogeneous GCN: one ``GraphConv`` per relation, mean-aggregated.

    Args:
        in_feats: size of the input node features.
        hid_feats: size of the hidden layer.
        out_feats: size of the output node embeddings.
        rel_names: edge-type names; each gets its own ``GraphConv`` per layer.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        # Required before assigning submodules to an nn.Module.
        super().__init__()
        self.conv1 = HeteroGraphConv({
            rel: GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='mean')
        self.conv2 = HeteroGraphConv({
            rel: GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='mean')

    def forward(self, graph, inputs):
        """Encode nodes.

        Args:
            graph: the input graph.
            inputs: dict mapping node type to input node features.

        Returns:
            Dict mapping node type to output node embeddings.
        """
        h = self.conv1(graph, inputs)
        h = {ntype: F.relu(feat) for ntype, feat in h.items()}
        h = self.conv2(graph, h)
        return h

def compute_loss(pos_score, neg_score):
    """Margin (hinge) loss: ``mean(max(0, 1 - pos + neg))`` over all pairs.

    Each positive score is compared against its own ``k`` negative scores.

    Args:
        pos_score: scores of the ``E`` positive edges; assumed shape ``(E, 1)``.
        neg_score: scores of the ``E * k`` negative edges, assumed ordered so
            that each consecutive run of ``k`` belongs to one positive edge
            (as produced by ``repeat_interleave(k)`` on the positive sources).

    Returns:
        A scalar tensor. When there are no positive edges (e.g. a batch that
        lacks this edge type), it returns 0 while keeping the autograd graph
        connected, so ``backward()`` on a sum of per-type losses still works.
    """
    n_edges = pos_score.shape[0]
    if n_edges == 0:
        # reshape(0, -1) is ambiguous and would raise, and mean() of an
        # empty tensor is NaN, so return an explicit zero instead.
        return (pos_score.sum() + neg_score.sum()) * 0.0
    # reshape (unlike view) also accepts non-contiguous inputs.
    neg_per_pos = neg_score.reshape(n_edges, -1)
    return (1 - pos_score + neg_per_pos).clamp(min=0).mean()

def construct_negative_graph(graph, k, etype):
    """Build a graph holding ``k`` corrupted edges per positive edge of ``etype``.

    Each positive edge ``(u, v)`` yields ``k`` negative edges ``(u, v')``. Each
    ``v'`` is drawn uniformly from the destination-type nodes of the *same*
    component graph when ``graph`` is a batch from ``dgl.batch`` /
    ``GraphDataLoader``, so negatives never link two different graphs. For
    an unbatched graph this reduces to uniform sampling over all nodes.

    Args:
        graph: a (possibly batched) DGL graph.
        k: number of negative samples per positive edge.
        etype: canonical edge type ``(src_type, relation, dst_type)``.

    Returns:
        A graph with the same node counts as ``graph`` and only ``etype`` edges.
        Its batch layout (per-component node/edge counts) is set to match, so
        it can be unbatched the same way as the positive graph.
    """
    _, _, vtype = etype
    src, _ = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)

    # Node counts per component and the global ID of each component's first
    # destination-type node. Unbatched graphs give single-element tensors.
    nodes_per_graph = graph.batch_num_nodes(vtype)
    edges_per_graph = graph.batch_num_edges(etype)
    node_offsets = torch.cumsum(nodes_per_graph, 0) - nodes_per_graph

    # dgl.batch gives each component a contiguous range of edge IDs, so
    # expanding by edge count maps every negative edge to its component.
    negs_per_graph = edges_per_graph * k
    n_choices = nodes_per_graph.repeat_interleave(negs_per_graph)
    offsets = node_offsets.repeat_interleave(negs_per_graph)

    # torch.randint cannot take a per-element upper bound, so scale a uniform
    # float instead. The minimum() guards against rounding up to n_choices.
    rand = torch.rand(neg_src.shape[0], dtype=torch.float64, device=src.device)
    local_dst = torch.minimum((rand * n_choices).long(), n_choices - 1)
    neg_dst = (local_dst + offsets).to(src.dtype)

    neg_graph = dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes})
    neg_graph.set_batch_num_nodes(
        {ntype: graph.batch_num_nodes(ntype) for ntype in neg_graph.ntypes})
    neg_graph.set_batch_num_edges({etype: negs_per_graph})
    return neg_graph

# train, val, test = 700, 100, 100 dgl graphs
dataloader = dgl.dataloading.GraphDataLoader(train, batch_size=128, shuffle=True, drop_last=False, num_workers=4)
k = 30  # negative samples drawn per positive edge
# Edge types to train on. Add more canonical edge types to train several
# relations at once; their per-type losses are summed into one loss.
# Note: model(...) re-runs the GNN encoder once per listed edge type.
target_etypes = [("face", "type3", "face")]
model = Model(train[0].ndata["x"].shape[1], 20, 1, train[0].etypes)
opt = torch.optim.Adam(model.parameters())
for epoch in range(2):
    running_loss = 0.
    num_batches = 0
    for i, hetero_graph in enumerate(dataloader):
        node_features = {"face": hetero_graph.nodes["face"].data["x"].float()}
        loss = 0.
        for etype in target_etypes:
            # Negatives are sampled within each graph of the batch.
            negative_graph = construct_negative_graph(hetero_graph, k, etype)
            pos_score, neg_score = model(hetero_graph, negative_graph, node_features, etype)
            loss = loss + compute_loss(pos_score, neg_score)
        # Backpropagate and update the weights; without this the model never trains.
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(f"iter {i} loss {loss.item()}")
        running_loss += loss.item()
        num_batches += 1
    # Average over the number of batches, not the batch size.
    print(f'{epoch + 1} loss: {running_loss / max(num_batches, 1):.5f}')

My questions are:

  1. How should I define construct_negative_graph() so that it works with the batched graphs produced by GraphDataLoader? In my case I don't need neighbor sampling, since the graphs are relatively small; I only need negative sampling.
  2. How do I extend the training loop to the other edge types? My graphs have only one node type but multiple edge types. Should I construct positive/negative graphs for each edge type and sum the losses during training? For example:
pos_score1, neg_score1 = #... neg sampling for edge_type_1
pos_score2, neg_score2 = #... neg sampling for edge_type_2
pos_score3, neg_score3 = #... neg sampling for edge_type_3

loss1 = compute_loss(pos_score1, neg_score1)
loss2 = compute_loss(pos_score2, neg_score2)
loss3 = compute_loss(pos_score3, neg_score3)

loss = loss1 + loss2 + loss3

Thank you!

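  1. Doesn't the existing construct_negative_graph() already work on a batched graph? A batched graph is just one large graph, so the function runs as-is. The only caveat is that torch.randint(0, graph.num_nodes(vtype), ...) draws negative destinations from the whole batch, so a negative edge may connect nodes from two different graphs. If negatives must stay within each graph, sample per component, using graph.batch_num_nodes(vtype) to compute node offsets.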
  2. Yes: compute the loss for each edge type and sum the losses before calling backward().