Trouble With Custom Negative Sampler for Edge Type


I’m trying to train a link prediction model on a heterograph using custom negative edges. The custom negatives apply to only one edge type; every other edge type should be trained as usual.

I created a custom negative sampler, shown below. It takes `ids_to_push`, a dict that maps a source node ID to the IDs of the nodes that should be used as its negative destinations. If the edge type is `user_follows` and the source node is in this dict, one of its listed destinations is chosen at random as the negative; otherwise the destination is sampled uniformly.

"""Negative samplers"""
import random
from collections.abc import Mapping

import torch
from dgl import backend as F
from dgl.dataloading.negative_sampler import _BaseNegativeSampler

class _BaseNegativeSamplerCustom(object):
    def _generate(self, g, eids, canonical_etype):
        raise NotImplementedError

    def __call__(self, g, eids):
        if isinstance(eids, Mapping):
            eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
            neg_pair = {k: self._generate(g, v, k) for k, v in eids.items()}
            assert len(g.etypes) == 1, \
                'please specify a dict of etypes and ids for graphs with multiple edge types'
            neg_pair = self._generate(g, eids, g.canonical_etypes[0])

        return neg_pair

class Uniform_Unless_Given(_BaseNegativeSamplerCustom):
    """Uniform negative sampler with per-source overrides for one edge type.

    For every positive edge, ``k`` negative destinations are drawn uniformly
    from the destination node type. For edges of type ``etype`` whose source
    node ID is a key of ``ids_to_push``, each drawn destination is replaced by
    an ID picked at random from ``ids_to_push[src]``.

    Parameters
    ----------
    k : int
        Number of negatives generated per positive edge.
    ids_to_push : dict, optional
        Maps an integer source node ID to a collection of destination node IDs
        to use as that source's negatives. NOTE(review): presumably these are
        IDs of the graph handed to the data loader -- confirm with the caller.
    etype : str, optional
        Edge type name whose negatives may be overridden. Defaults to
        ``'user_follows'``, the type the original code hard-coded.
    """

    def __init__(self, k, ids_to_push=None, etype='user_follows'):
        self.k = k
        self.ids_to_push = ids_to_push
        self.etype = etype

    def _generate(self, g, eids, canonical_etype):
        """Return ``(src, dst)`` with ``k`` negatives per edge in ``eids``."""
        _, c_etype, vtype = canonical_etype
        dtype = F.dtype(eids)
        ctx = F.context(eids)
        # we want to generate k negative examples for each positive
        shape = (F.shape(eids)[0] * self.k,)
        # source IDs of the positive edges, each repeated k times
        src, _ = g.find_edges(eids, etype=canonical_etype)
        src = F.repeat(src, self.k, 0)
        # uniform destinations; kept unless an override applies below
        dst = F.randint(shape, dtype, ctx, 0, g.number_of_nodes(vtype))

        if not self.ids_to_push or c_etype != self.etype:
            return src, dst

        # Compare plain Python ints: a tensor element hashes by identity, so a
        # lookup such as ``tensor(5) in {5: [...]}`` would never match.
        dst_ids = dst.tolist()
        for i, src_id in enumerate(src.tolist()):
            candidates = self.ids_to_push.get(src_id)
            if candidates is None:
                continue
            candidates = list(candidates)
            if candidates:  # an empty collection keeps the uniform sample
                dst_ids[i] = int(random.choice(candidates))
        # Keep the dtype/device of the uniform sample; torch.LongTensor would
        # always produce an int64 tensor on the CPU.
        dst = torch.tensor(dst_ids, dtype=dtype, device=ctx)
        return src, dst

I then train this model for link prediction in the normal way as discussed in the docs:

negative_sampler_to_use = Uniform_Unless_Given(5, negative_samples)

# One int64 edge-ID tensor per canonical edge type, covering all edges of that type.
# NOTE(review): every edge type listed here can show up in positive_graph, and the
# score predictor reads 'h' on both endpoint node types of each one.
# NOTE(review): IDs are taken from `g` while the loader samples from
# `overall_graph._g[0]` -- confirm these are the same graph.
train_eid_dict = {}
for src_type, rel_name, dst_type in g.canonical_etypes:
    edge_count = g.num_edges(rel_name)
    train_eid_dict[(src_type, rel_name, dst_type)] = torch.arange(edge_count, dtype=torch.int64)

dataloader_lp_for_communities = dgl.dataloading.EdgeDataLoader(
    overall_graph._g[0],
    train_eid_dict,
    sampler,
    negative_sampler=negative_sampler_to_use,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
    num_workers=args.num_workers,
)

When I run it, I get an error in my score predictor module:

class HeteroScorePredictor(nn.Module):
    """Score edges with a dot product of their endpoint features.

    ``forward(edge_subgraph, x)`` stores ``x`` as node feature ``'h'`` and, for
    each canonical edge type with at least one edge, writes
    ``u_dot_v('h', 'h')`` into edge feature ``'score'``. It returns
    ``edge_subgraph.edata['score']``.
    """

    def forward(self, edge_subgraph, x):
        with edge_subgraph.local_scope():
            edge_subgraph.ndata['h'] = x
            for etype in edge_subgraph.canonical_etypes:
                if edge_subgraph.num_edges(etype) <= 0:
                    # Nothing to score for this edge type.
                    continue
                if isinstance(x, dict):
                    # A per-node-type feature dict may lack some node types
                    # (presumably ones the encoder returned nothing for);
                    # u_dot_v would then fail with a bare KeyError: 'h'.
                    src_type, _, dst_type = etype
                    missing = [nt for nt in (src_type, dst_type) if nt not in x]
                    if missing:
                        raise KeyError(
                            "no 'h' features for node type(s) %s, required by "
                            "edge type %s; features were given for %s"
                            % (missing, etype, sorted(x)))
                edge_subgraph.apply_edges(
                    dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
            return edge_subgraph.edata['score']

Here is the error:

Traceback (most recent call last):
  File "/home_directory/lib/python3.6/multiprocessing/", line 258, in _bootstrap
  File "/home_directory/lib/python3.6/multiprocessing/", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "", line 5189, in running_code
    pos_score, neg_score = model(blocks, node_features, g=positive_graph, neg_g=negative_graph)
  File "/home_directory/lib/python3.6/site-packages/torch/nn/modules/", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home_directory/lib/python3.6/site-packages/torch/nn/parallel/", line 458, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home_directory/lib/python3.6/site-packages/torch/nn/modules/", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "", line 338, in forward
    pos_score = self.pred(g, x) 
  File "/home_directory/lib/python3.6/site-packages/torch/nn/modules/", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "", line 320, in forward
    edge_subgraph.apply_edges(dgl.function.u_dot_v('h', 'h', 'score'), etype=etype)
  File "/home_directory/lib/python3.6/site-packages/dgl/", line 4423, in apply_edges
    edata = core.invoke_gsddmm(g, func)
  File "/home_directory/lib/python3.6/site-packages/dgl/", line 239, in invoke_gsddmm
    y = alldata[func.rhs][func.rhs_field]
  File "/home_directory/lib/python3.6/site-packages/dgl/", line 66, in __getitem__
    return self._graph._get_n_repr(self._ntid, self._nodes)[key]
  File "/home_directory/lib/python3.6/site-packages/dgl/", line 393, in __getitem__
    return self._columns[name].data
KeyError: 'h'

I’m not sure what the problem is here, and I’m not sure how to debug it.

Thanks in advance for the help!

  1. The error message seems to be incomplete. Is the problem that the key `h` cannot be found?
  2. Does this happen in the forward pass on `neg_graph`?

Could you check whether all the src and dst nodes returned by your custom negative sampler have the feature `h`? Or, for debugging, try always returning a fixed src/dst pair, e.g. 0 and 1.

The full error message is `KeyError: 'h'`; there is no additional text such as “cannot be found”.

It happens in the forward pass on the positive graph. Below is the forward code; the error is raised on the line `pos_score = self.pred(g, x)`:

def forward(self, blocks, x, g, neg_g):
    """Encode nodes with ``self.sage``, then optionally score two graphs.

    When ``g`` is None, only the encoder output is returned. Otherwise the
    result is ``(pos_score, neg_score)``: ``self.pred`` applied to ``g`` and
    then to ``neg_g``, both using the same encoder output.
    """
    embeddings = self.sage(blocks, x)
    if g is not None:
        # tuple elements evaluate left to right: positive graph first
        return self.pred(g, embeddings), self.pred(neg_g, embeddings)
    return embeddings

I printed out the node features:

{'source': tensor([[-1.4817e+00, -1.8118e-01,  1.4381e-01,  ...,  1.5922e-03,
          2.0126e-03,  3.1638e-01],
        [-6.6626e-01, -1.9062e-01, -3.1084e-01,  ...,  1.6274e-04,
          0.0000e+00,  9.4086e-01],
        [ 4.6148e-02,  4.7933e-01, -4.0269e-02,  ...,  2.5043e-03,
          2.7407e-03,  7.0756e-01],
        [ 4.4293e-01,  7.2503e-01,  6.5911e-01,  ...,  6.8932e-01,
          3.1554e-02,  3.5854e-01],
        [ 5.8352e-01,  2.7026e-01,  1.1440e-01,  ...,  1.4909e-01,
          8.1176e-01,  9.2421e-01],
        [ 2.8641e-01,  1.4279e-01,  3.8616e-01,  ...,  6.1561e-01,
          4.7736e-01,  4.5671e-01]], device='cuda:0'), 'user': tensor([[-0.5143, -0.2436,  0.7256,  ...,  0.0000,  0.0000,  0.0000],
        [-0.1877,  0.0263, -0.3932,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3074,  0.3796,  0.1658,  ...,  0.0000,  0.0000,  0.0000],
        [-1.6278, -0.3349,  0.8372,  ...,  0.0000,  0.0000,  0.0000],
        [-0.3414,  0.9643,  0.2000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.2282, -0.1879,  0.0570,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0'), 'good_source': tensor([], device='cuda:0', size=(0, 778)), 'article': tensor([[-0.0046, -0.3887,  0.8251,  ...,  0.0000,  0.0000,  0.0000],
        [-0.1073,  0.2448, -0.4397,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2641, -0.1296, -0.5983,  ...,  0.0000,  0.0000,  0.0000],
        [-0.1708,  0.1430,  0.1071,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.3968, -0.2619,  1.3415,  ...,  0.0000,  0.0000,  0.0000],
        [-0.2960,  0.2684,  0.1726,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0'), 'low_source': tensor([], device='cuda:0', size=(0, 778)), 'mid_source': tensor([], device='cuda:0', size=(0, 778))}

And the positive graph:

Graph(num_nodes={'article': 21, 'dissimilar_special': 2, 'good_source': 0, 'low_source': 0, 'mid_source': 0, 'similar_special': 0, 'source': 61, 'user': 773},
      num_edges={('article', 'article_dissimilar_special', 'dissimilar_special'): 0, ('article', 'article_similar_special', 'similar_special'): 0, ('article', 'does_not_talk_similar_article', 'article'): 0, ('article', 'has_talker', 'user'): 0, ('article', 'is_published_by', 'source'): 3, ('article', 'talks_similar_article', 'article'): 0, ('article', 'talks_similar_article_neg', 'article'): 0, ('article', 'will_not_be_propagated_by', 'user'): 0, ('dissimilar_special', 'dissimilar_special_article', 'article'): 0, ('dissimilar_special', 'dissimilar_special_source', 'source'): 0, ('dissimilar_special', 'dissimilar_special_user', 'user'): 0, ('good_source', 'links', 'source'): 0, ('low_source', 'links_low', 'source'): 0, ('mid_source', 'links_mid', 'source'): 0, ('similar_special', 'similar_special_article', 'article'): 0, ('similar_special', 'similar_special_source', 'source'): 0, ('similar_special', 'similar_special_user', 'user'): 0, ('source', 'has_article', 'article'): 3, ('source', 'has_follower', 'user'): 0, ('source', 'has_good', 'good_source'): 0, ('source', 'has_low', 'low_source'): 0, ('source', 'has_mid', 'mid_source'): 0, ('source', 'source_dissimilar_special', 'dissimilar_special'): 0, ('source', 'source_similar_special', 'similar_special'): 0, ('source', 'talks_similar_source', 'source'): 0, ('user', 'connects_with_contr', 'user'): 0, ('user', 'does_not_follow', 'user'): 0, ('user', 'does_not_propagate', 'article'): 0, ('user', 'follows', 'source'): 7, ('user', 'is_influenced_by_contr', 'user'): 0, ('user', 'talks_about', 'article'): 0, ('user', 'user_dissimilar_special', 'dissimilar_special'): 1, ('user', 'user_follows', 'user'): 111, ('user', 'user_similar_special', 'similar_special'): 0},
      metagraph=[('article', 'dissimilar_special', 'article_dissimilar_special'), ('article', 'similar_special', 'article_similar_special'), ('article', 'article', 'does_not_talk_similar_article'), ('article', 'article', 'talks_similar_article'), ('article', 'article', 'talks_similar_article_neg'), ('article', 'user', 'has_talker'), ('article', 'user', 'will_not_be_propagated_by'), ('article', 'source', 'is_published_by'), ('dissimilar_special', 'article', 'dissimilar_special_article'), ('dissimilar_special', 'source', 'dissimilar_special_source'), ('dissimilar_special', 'user', 'dissimilar_special_user'), ('similar_special', 'article', 'similar_special_article'), ('similar_special', 'source', 'similar_special_source'), ('similar_special', 'user', 'similar_special_user'), ('user', 'user', 'connects_with_contr'), ('user', 'user', 'does_not_follow'), ('user', 'user', 'is_influenced_by_contr'), ('user', 'user', 'user_follows'), ('user', 'article', 'does_not_propagate'), ('user', 'article', 'talks_about'), ('user', 'source', 'follows'), ('user', 'dissimilar_special', 'user_dissimilar_special'), ('user', 'similar_special', 'user_similar_special'), ('source', 'article', 'has_article'), ('source', 'user', 'has_follower'), ('source', 'good_source', 'has_good'), ('source', 'low_source', 'has_low'), ('source', 'mid_source', 'has_mid'), ('source', 'dissimilar_special', 'source_dissimilar_special'), ('source', 'similar_special', 'source_similar_special'), ('source', 'source', 'talks_similar_source'), ('good_source', 'source', 'links'), ('low_source', 'source', 'links_low'), ('mid_source', 'source', 'links_mid')])

The model is trained as follows:

# NOTE(review): the `.to(...)` targets below were lost when this was pasted. The
# printed node features live on cuda:0, so that device is assumed here -- replace
# it with the project's actual device variable.
device = torch.device('cuda:0')
# enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so tqdm can see the
# loader's length and show a total.
for iteration, (input_nodes, positive_graph, negative_graph, blocks) in enumerate(tqdm(dataloader_lp_for_communities)):
    # DGL expects graphs and the feature tensors used with them on one device.
    blocks = [b.to(device) for b in blocks]

    positive_graph = positive_graph.to(device)
    negative_graph = negative_graph.to(device)
    node_features = get_features_given_blocks(g, args, blocks, graph_style)

    # forward propagation by using all nodes and extracting the user embeddings
    pos_score, neg_score = model(blocks, node_features, g=positive_graph, neg_g=negative_graph)

Thanks for looking into it!