Question about negative sampling

Looking at how negative samples are generated in the DGL source code, the destination nodes appear to be drawn uniformly at random. How can I exclude nodes that already have an edge to the source node?

The sampler is set up in pytorch/graphsage/link_pred.py:

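from dgl.dataloading import as_edge_prediction_sampler, negative_sampler

# `sampler` and `reverse_eids` are defined earlier in link_pred.py.
sampler = as_edge_prediction_sampler(
    sampler,
    exclude="reverse_id",
    reverse_eids=reverse_eids,
    negative_sampler=negative_sampler.Uniform(1),  # 1 negative per positive edge
)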

That call ends up in DGL's negative_sampler.py:

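def _generate(self, g, eids, canonical_etype):
    _, _, vtype = canonical_etype
    shape = F.shape(eids)
    dtype = F.dtype(eids)
    ctx = F.context(eids)
    shape = (shape[0] * self.k,)
    # Keep the source of each positive edge and repeat it k times.
    src, _ = g.find_edges(eids, etype=canonical_etype)
    src = F.repeat(src, self.k, 0)
    # Destinations are drawn uniformly from all nodes of type vtype,
    # with no check for whether (src, dst) is already an edge.
    dst = F.randint(shape, dtype, ctx, 0, g.num_nodes(vtype))
    return src, dst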

Hi @bear, sorry for the late reply. It seems you want true negatives, i.e. sampled pairs that are guaranteed not to be edges in the graph. Take a look at this negative sampler: https://github.com/dmlc/dgl/blob/b8886900837e3fc73972215b7c5e9b3e127acbfc/python/dgl/dataloading/negative_sampler.py#L79
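To make the idea concrete, here is a minimal plain-Python sketch of rejection sampling for true negatives. It is not DGL's implementation and uses no DGL calls; the function name `sample_true_negatives` and its parameters (`edges`, `num_nodes`, `k`, `max_tries`, `seed`) are hypothetical names chosen for illustration. It keeps the source of each positive edge, draws a random destination, and redraws whenever the pair is already an edge.

```python
import random


def sample_true_negatives(edges, num_nodes, k=1, max_tries=100, seed=None):
    """For each positive (src, dst) edge, draw up to k destinations
    that are NOT connected to src by an existing edge.

    Hypothetical helper for illustration only; not part of DGL.
    """
    rng = random.Random(seed)

    # Adjacency sets make the "is this already an edge?" check O(1).
    neighbors = {}
    for s, d in edges:
        neighbors.setdefault(s, set()).add(d)

    neg_src, neg_dst = [], []
    for s, _ in edges:
        found = 0
        for _attempt in range(max_tries):
            if found == k:
                break
            cand = rng.randrange(num_nodes)
            # Reject candidates that would be false negatives.
            if cand in neighbors[s]:
                continue
            neg_src.append(s)
            neg_dst.append(cand)
            found += 1
    return neg_src, neg_dst


# Tiny directed graph with 5 nodes and 4 edges.
edges = [(0, 1), (0, 2), (1, 2), (2, 0)]
neg_src, neg_dst = sample_true_negatives(edges, num_nodes=5, k=2, seed=0)
print(list(zip(neg_src, neg_dst)))
```

Note that this sketch only rejects existing edges, so a self-pair such as (0, 0) can still be drawn, and a source may end up with fewer than k negatives if `max_tries` runs out on a very dense graph. On a real DGLGraph, the same membership check could probably be done in a vectorized way with `g.has_edges_between(src, dst)` instead of Python sets.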
