Hi,
Newbie to GNNs and DGL here. I am trying to implement a link prediction model on a heterogeneous graph. To train the model, I need positive and negative edge samples. To create a 'positive' subgraph (`pos_g`) and a 'negative' subgraph (`neg_g`), I rely on the `dgl.contrib.sampling.EdgeSampler` class. However, this class does not work with a HeteroGraph. When I run my sampler, I get the following error:
```
/opt/dgl/include/dgl/packed_func_ext.h:117: Check failed: ObjectTypeChecker<TObjectRef>::Check(sptr.get()): Expected type graph.Graph but get graph.HeteroGraph
```
Here is the code that produces the error:
```python
import dgl


def edge_sampler(g, neg_sample_size, proportion_of_edges, edges=None, return_false_neg=False):
    # Count edges across all canonical edge types of the heterograph
    num_of_edges = 0
    for etype in g.canonical_etypes:
        num_of_edges += g.number_of_edges(etype)
    # EdgeSampler expects a homogeneous graph, so passing a heterograph raises the error above
    sampler = dgl.contrib.sampling.EdgeSampler(g,
                                               batch_size=int(num_of_edges * proportion_of_edges),
                                               seed_edges=edges,
                                               neg_sample_size=neg_sample_size,
                                               negative_mode='tail',
                                               shuffle=True,
                                               return_false_neg=return_false_neg)
    # Draw a single batch: (positive subgraph, negative subgraph)
    sampler = iter(sampler)
    return next(sampler)


pos_g, neg_g = edge_sampler(g, neg_sample_size, proportion_of_edges, valid_eids, return_false_neg=True)
```
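For context, as I understand it, `negative_mode='tail'` builds each negative edge by replacing the tail of a positive edge with a random node, and `return_false_neg=True` asks the sampler to flag negatives that happen to be real edges. Here is a minimal, DGL-free sketch of that idea for a single edge type, using plain Python lists in place of edge tensors. `corrupt_tails` is a hypothetical helper of my own, not a DGL function, and this is my reading of the semantics rather than EdgeSampler's actual implementation.

```python
import random


def corrupt_tails(pos_src, num_dst_nodes, neg_sample_size, existing_edges, rng):
    """Tail-mode negative sampling for ONE edge type (illustrative sketch).

    For every positive head u, draw `neg_sample_size` random tails v' from the
    destination node type and emit the negative edge (u, v'). Each negative
    edge is also flagged True if (u, v') is in `existing_edges`, i.e. it is a
    false negative that happens to be a real edge.
    """
    neg_src, neg_dst, false_neg = [], [], []
    for u in pos_src:
        for _ in range(neg_sample_size):
            v_neg = rng.randrange(num_dst_nodes)  # uniform over destination nodes
            neg_src.append(u)
            neg_dst.append(v_neg)
            false_neg.append((u, v_neg) in existing_edges)
    return neg_src, neg_dst, false_neg


# Toy edge type with 3 destination nodes; every positive edge gets 2 negatives.
rng = random.Random(0)
src, dst = [0, 0, 1, 2], [1, 2, 2, 0]
neg_src, neg_dst, false_neg = corrupt_tails(src, 3, 2, set(zip(src, dst)), rng)
print(list(zip(neg_src, neg_dst, false_neg)))
```

The flagged negatives would typically be masked out of the loss rather than treated as true non-edges.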
Is there a workaround to use the `EdgeSampler` class on a HeteroGraph? If not, do you have any recommendations for edge sampling on a HeteroGraph?
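One workaround I am considering: the negatives for an edge type only need nodes of that type's destination node type, so sampling could be done separately for each canonical edge type. Below is a rough sketch with plain Python lists; `sample_hetero_edges` and its dict layout are my own assumptions, not DGL APIs. The input lists could come from `g.edges(etype=c_etype)` for each entry of `g.canonical_etypes`, and the returned dicts use the `{(src_type, rel, dst_type): (src_ids, dst_ids)}` layout that `dgl.heterograph` takes, so `pos_g` and `neg_g` could then be built with `dgl.heterograph(pos, num_nodes_dict=num_nodes)` and `dgl.heterograph(neg, num_nodes_dict=num_nodes)`, possibly after converting the lists to tensors.

```python
import random


def sample_hetero_edges(edges, num_nodes, proportion_of_edges, neg_sample_size, seed=0):
    """Sample positive and tail-corrupted negative edges per canonical edge type.

    edges:     {(src_type, rel, dst_type): (src_ids, dst_ids)}
    num_nodes: {node_type: number of nodes of that type}
    Returns (pos, neg), both shaped like `edges`.
    """
    rng = random.Random(seed)
    pos, neg = {}, {}
    for c_etype, (src, dst) in edges.items():
        dst_type = c_etype[2]
        # Same fraction of edges from every edge type.
        n_pos = int(len(src) * proportion_of_edges)
        eids = rng.sample(range(len(src)), n_pos)
        p_src = [src[i] for i in eids]
        p_dst = [dst[i] for i in eids]
        # Each positive head is repeated neg_sample_size times with a random
        # tail drawn only from nodes of the destination node type.
        n_src = [u for u in p_src for _ in range(neg_sample_size)]
        n_dst = [rng.randrange(num_nodes[dst_type]) for _ in n_src]
        pos[c_etype] = (p_src, p_dst)
        neg[c_etype] = (n_src, n_dst)
    return pos, neg


edges = {
    ("user", "follows", "user"): ([0, 1, 2], [1, 2, 0]),
    ("user", "plays", "game"): ([0, 0, 1, 2], [0, 1, 1, 0]),
}
num_nodes = {"user": 3, "game": 2}
pos, neg = sample_hetero_edges(edges, num_nodes, proportion_of_edges=0.5, neg_sample_size=2)
```

Unlike my original code, which used one batch size over all edges, this samples the same fraction from each edge type so every relation is represented, and it does not filter false negatives. If upgrading is an option, newer DGL releases also provide heterograph-aware negative samplers under `dgl.dataloading` (e.g. `dgl.dataloading.negative_sampler.Uniform`); the documentation for the installed version would be the place to check.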
Thanks in advance!