Hi @BarclayII, sorry I forgot a few important details to replicate the bug.
Firstly, I was trying to run sample_neighbors
with edge-wise importance sampling by assigning a computed "prob"
to each edge of each metapath. It seems that the crash happens only when assigning "prob"
to the edges (of the reverse metapaths) based on the src
node ids. My code is:
from dgl.dataloading import BlockSampler
class ImportanceSampler(BlockSampler):
    """Multi-layer neighbor sampler that weights each edge by the inverse
    square root of a precomputed per-source-node count, looked up by
    ``(src_nid, metapath, ntype)``.

    Parameters
    ----------
    fanouts : list
        One fanout spec per layer (int or Dict[etype, int]).
    metapaths : list
        Metapaths of interest (kept for reference by callers).
    degree_counts : dict
        Maps ``(nid, metapath, ntype)`` to a count; nodes absent from the
        dict default to 1.0 (uniform weight).
    edge_dir : str
        Either ``"in"`` or ``"out"``; direction used to build the subgraph
        and to sample neighbors.
    return_eids : bool
        Forwarded to ``BlockSampler``.
    """

    def __init__(self, fanouts, metapaths, degree_counts: dict, edge_dir="in", return_eids=False):
        super().__init__(len(fanouts), return_eids)
        self.fanouts = fanouts
        self.edge_dir = edge_dir
        # degree_counts is a Dict[(nid, metapath, ntype), int];
        # defaultdict gives 1.0 for any node not present in the dict.
        self.degree_counts = defaultdict(lambda: 1.0, degree_counts)
        self.metapaths = metapaths

    def sample_frontier(self, block_id, g: dgl.DGLGraph, seed_nodes: Dict[str, torch.Tensor]):
        """Build the frontier for layer ``block_id`` by importance-sampling
        neighbors of ``seed_nodes`` with per-edge probability
        ``sqrt(1 / degree_count(src))``.
        """
        fanouts = self.fanouts[block_id]
        if self.edge_dir == "in":
            sg = dgl.in_subgraph(g, seed_nodes)
        elif self.edge_dir == "out":
            sg = dgl.out_subgraph(g, seed_nodes)
        else:
            # Previously an unsupported value fell through and raised a
            # confusing NameError on `sg`; fail fast instead.
            raise ValueError(f"edge_dir must be 'in' or 'out', got {self.edge_dir!r}")

        # Assign edge probabilities for each canonical edge type (metapath).
        for metapath in sg.canonical_etypes:
            ntype = metapath[0]
            src, _ = sg.edges(etype=metapath)
            # apply_() mutates the tensor in place, and the tensor returned by
            # sg.edges() may share storage with graph internals — mutating it
            # directly crashed on reverse metapaths.  Work on a detached copy.
            edge_weight = src.detach().clone().apply_(
                lambda nid: self.degree_counts[(nid, metapath, ntype)]  # weight by src node degree
            ).to(torch.float)
            # Fixed: was `edge_weights` (undefined name, NameError).
            sg.edges[metapath].data["prob"] = torch.sqrt(1.0 / edge_weight)

        frontier = sample_neighbors(g=sg,
                                    nodes=seed_nodes,
                                    prob="prob",
                                    fanout=fanouts,
                                    edge_dir=self.edge_dir)
        print("n_nodes", sum([k.size(0) for k in seed_nodes.values()]),
              "fanouts", {k[1]: v for k, v in fanouts.items()} if isinstance(fanouts, dict) else fanouts,
              "edges", frontier.num_edges(),
              "pruned", sg.num_edges() - frontier.num_edges())
        return frontier
Thanks for your help!
DGL: 0.6.0.post1 (dgl-cu110)
CUDA: cu110
PyTorch: 1.8.0, 1.7.1