When I look at how the nodes for negative sampling are generated in the DGL source code, they appear to be drawn at random. How can I exclude nodes that are already connected to the source node by an edge?
From `pytorch/graphsage/link_pred.py`:
# Rebind `sampler` to an edge-prediction sampler built on top of it.
# NOTE(review): per the DGL docs, `exclude="reverse_id"` together with
# `reverse_eids` controls which edges (the minibatch's positive edges and
# their reverse edges) are removed from the sampled neighborhood. It is not
# assumed to filter the negative samples in any way — confirm in the
# `as_edge_prediction_sampler` documentation for your DGL version.
# `negative_sampler.Uniform(1)` presumably requests one negative per positive
# edge. Its `_generate` draws destinations at random without checking them
# against existing edges, so sampled "negatives" can be real edges.
sampler = as_edge_prediction_sampler(
sampler,
exclude="reverse_id",
reverse_eids=reverse_eids,
negative_sampler=negative_sampler.Uniform(1),
)
which leads into DGL's `negative_sampler.py`:
def _generate(self, g, eids, canonical_etype):
    """Corrupt each positive edge into ``self.k`` negative (src, dst) pairs.

    The source endpoint of every edge in ``eids`` is kept and repeated
    ``self.k`` times along dimension 0. Each copy is paired with a
    destination node ID drawn at random by ``F.randint``, with lower bound 0
    and upper bound ``g.num_nodes()`` of the destination node type.

    Parameters
    ----------
    g : graph
        Graph that the positive edges belong to.
    eids : tensor
        IDs of the positive edges of type ``canonical_etype``. Its dtype and
        device are reused for the sampled destination IDs.
    canonical_etype : tuple
        Three-element edge-type triple. Only its last element (the
        destination node type) is used, to size the sampling range.

    Returns
    -------
    tuple
        ``(src, dst)`` with ``len(eids) * self.k`` entries each.

    Notes
    -----
    Destinations are NOT filtered. Nothing here checks whether ``(src, dst)``
    is already an edge of ``g``, or whether ``dst == src``, so some
    "negatives" may actually be positive edges or self-loops.
    NOTE(review): if true negatives are required, consider DGL's
    ``GlobalUniform`` negative sampler, which is documented to avoid
    existing edges, or post-filter with ``g.has_edges_between``. Verify both
    against your DGL version.
    """
    _src_type, _rel_type, dst_type = canonical_etype
    num_pos = F.shape(eids)[0]
    id_dtype = F.dtype(eids)
    device = F.context(eids)

    pos_src, _pos_dst = g.find_edges(eids, etype=canonical_etype)
    # One copy of every positive edge's source per requested negative.
    neg_src = F.repeat(pos_src, self.k, 0)
    # Random destination IDs of the destination node type; no edge check.
    neg_dst = F.randint(
        (num_pos * self.k,), id_dtype, device, 0, g.num_nodes(dst_type)
    )
    return neg_src, neg_dst