@BarclayII For now, the code below is a hack that works for me. What do you think about it?
- This union might add overhead on every iteration.
- I might need to move the tensors back to the CPU to compute this union.
The ideal way to handle this would be inside the DataLoader itself. That said, do you think we can optimize it in any other way?
import numpy as np
import torch
import dgl
from dgl.dataloading import BlockSampler
class CustomNeighborSampler(BlockSampler):
    """Multi-layer neighbor sampler that always excludes a fixed set of
    edges (e.g. negative-sampled edges), merged with whatever per-batch
    ``exclude_eids`` the DataLoader passes in.

    Parameters
    ----------
    fanouts : list
        Number of neighbors to sample per layer (one entry per layer);
        forwarded to ``g.sample_neighbors``.
    exclude_neg_eids : dict or None
        Edge IDs to always exclude, keyed by edge type.  Values are
        assumed to be NumPy arrays or CPU tensors — TODO confirm with the
        caller.  ``None`` disables the static exclusion entirely.
    edge_dir : str
        Sampling direction; forwarded to ``g.sample_neighbors``.
    prob : str or None
        Name of the edge feature used as sampling probability; forwarded
        to ``g.sample_neighbors``.
    replace : bool
        Whether to sample with replacement.
    prefetch_node_feats, prefetch_labels, prefetch_edge_feats, output_device
        Forwarded unchanged to :class:`BlockSampler`.
    """

    def __init__(self, fanouts, exclude_neg_eids=None, edge_dir='in', prob=None,
                 replace=False, prefetch_node_feats=None, prefetch_labels=None,
                 prefetch_edge_feats=None, output_device=None):
        super().__init__(prefetch_node_feats=prefetch_node_feats,
                         prefetch_labels=prefetch_labels,
                         prefetch_edge_feats=prefetch_edge_feats,
                         output_device=output_device)
        self.fanouts = fanouts
        self.exclude_neg_eids = exclude_neg_eids
        self.edge_dir = edge_dir
        self.prob = prob
        self.replace = replace

    def _merge_exclude_eids(self, g, exclude_eids):
        """Union the static ``exclude_neg_eids`` with the per-batch
        ``exclude_eids``.  Returns the merged dict, or the batch dict
        unchanged when there is no static set.
        """
        if self.exclude_neg_eids is None:
            # Nothing static to merge; pass the batch exclusions through
            # (possibly None) instead of crashing on `None.keys()`.
            return exclude_eids
        if exclude_eids is None:
            # DataLoader may legitimately pass no per-batch exclusions;
            # the original code raised TypeError subscripting None here.
            exclude_eids = {}
        # Place results where sampling happens rather than on a hard-coded
        # 'cuda' device, so CPU-only runs work too.
        device = self.output_device if self.output_device is not None else g.device
        merged = {}
        # Union over BOTH key sets: iterating only self.exclude_neg_eids
        # (as the original did) silently dropped batch exclusions for edge
        # types absent from the static set.
        for etype in set(self.exclude_neg_eids) | set(exclude_eids):
            static = np.asarray(self.exclude_neg_eids.get(etype, []), dtype=np.int64)
            batch = exclude_eids.get(etype)
            batch = (np.asarray([], dtype=np.int64) if batch is None
                     else batch.cpu().numpy())
            union = torch.from_numpy(np.union1d(static, batch))
            # Keep the graph's own ID dtype: casting to int32 (as the
            # original did) truncates edge IDs on int64 graphs.
            merged[etype] = union.to(g.idtype).to(device)
        return merged

    def sample_blocks(self, g, seed_nodes, exclude_eids=None):
        """Sample one block per fanout, excluding the merged edge set.

        Returns ``(input_nodes, output_nodes, blocks)`` with ``blocks``
        ordered input-layer-first.
        """
        exclude_eids = self._merge_exclude_eids(g, exclude_eids)
        output_nodes = seed_nodes
        blocks = []
        # Build from the output layer inward; insert at the front so the
        # returned list is ordered input-layer-first.
        for fanout in reversed(self.fanouts):
            frontier = g.sample_neighbors(
                seed_nodes, fanout, edge_dir=self.edge_dir, prob=self.prob,
                replace=self.replace, output_device=self.output_device,
                exclude_edges=exclude_eids)
            eid = frontier.edata[dgl.EID]
            block = dgl.to_block(frontier, seed_nodes)
            # Preserve the original edge IDs on the block (to_block
            # overwrites EID with frontier-local IDs).
            block.edata[dgl.EID] = eid
            seed_nodes = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return seed_nodes, output_nodes, blocks