I was referring to the following link to implement the NegativeSampler for Heterogeneous Graphs. However, I am getting 0 edges into the negative graph.
https://docs.dgl.ai/guide/minibatch-link.html?highlight=negativesampler
Also, there was a small mistake in the documentation in the above link: It should be
train_eid_dict = {etype: g.edges(etype=etype, form='eid') for etype in g.etypes}
instead of train_eid_dict = {g.edges(etype=etype, form='eid') for etype in g.etypes}
When I made the above change and pass it to the edge data loader, It’s neither not throwing any error nor giving any negative edges.
Replicated the problem using the below sample code. Can someone debug where is it going wrong?
g = dgl.heterograph({ ('user', 'watched', 'item'): (torch.tensor([0, 1]), torch.tensor([1, 2])), ('item', 'watched-by', 'user'): (torch.tensor([0, 1]), torch.tensor([2, 3])), }) class HeteroNegativeSampler(object): def __init__(self, g, k): # caches the probability distribution self.weights = { etype: g.in_degrees(etype=etype).float() ** 0.5 for _, etype, _ in g.canonical_etypes } self.k = k def __call__(self, g, eids_dict): result_dict = {} for etype, eids in eids_dict.items(): src, _ = g.find_edges(eids, etype=etype) src = src.repeat_interleave(self.k) dst = self.weights[etype].multinomial(len(src), replacement=True) result_dict[etype] = (src, dst) return result_dict train_eid_dict = {etype: g.edges(etype=etype, form='eid') for etype in g.etypes} sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2) dataloader = dgl.dataloading.EdgeDataLoader( g, train_eid_dict, sampler, negative_sampler=HeteroNegativeSampler(g, 5), batch_size=1024, shuffle=True, drop_last=False, num_workers=4) input_nodes, pos_graph, neg_graph, bipartites = next(iter(dataloader)) print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges()) print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())
Output:
Positive graph # nodes: 7 # edges: 4
Negative graph # nodes: 7 # edges: 0