How to properly write a custom data sampler

I found some code written for a DGL version before v0.8 that uses a custom data sampler together with dgl.dataloading.NodeDataLoader() for subgraph sampling, but NodeDataLoader has since been deprecated. How should I modify my data sampler so that it works with dgl.dataloading.DataLoader?

class PinSAGESampler(object):
    """Stack blocks produced by repeated random-walk neighbor sampling.

    A ``dgl.sampling.RandomWalkNeighborSampler`` is created once in the
    constructor and then invoked once per layer by :meth:`sample_blocks`.
    """

    def __init__(self, g, random_walk_length, random_walk_restart_prob, num_random_walk, num_neighbor, num_layer):
        """Store the graph and layer count and build the random-walk sampler.

        All arguments except ``num_layer`` are forwarded, in order, to
        ``dgl.sampling.RandomWalkNeighborSampler``.
        """
        self.g = g
        self.num_layer = num_layer
        walk_args = (
            g,
            random_walk_length,
            random_walk_restart_prob,
            num_random_walk,
            num_neighbor,
        )
        self.sampler = dgl.sampling.RandomWalkNeighborSampler(*walk_args)

    def sample_blocks(self, _, seed_nodes):
        """Return a list of ``num_layer`` blocks grown outward from ``seed_nodes``.

        The first positional argument is ignored. The block built in the
        last iteration is placed first in the returned list, and the block
        built directly from ``seed_nodes`` is placed last.
        """
        layers = []
        current_seeds = seed_nodes
        for _layer in range(self.num_layer):
            neighborhood = self.sampler(current_seeds)
            layer_block = dgl.to_block(neighborhood, current_seeds)
            # The source nodes of this block become the seeds of the next one.
            current_seeds = layer_block.srcdata[dgl.NID]
            layers.append(layer_block)
        # Blocks were collected seed-side first; flip to the expected order.
        layers.reverse()
        return layers
...
...

    # Keyword arguments spell out what each constructor value means.
    sampler = PinSAGESampler(
        g,
        random_walk_length=3,
        random_walk_restart_prob=0.5,
        num_random_walk=4,
        num_neighbor=5,
        num_layer=3,
    )
    # Node-wise loader over the training node IDs, driven by the custom sampler.
    dataloader = dgl.dataloading.NodeDataLoader(
        g, train_nids, sampler,
        use_ddp=True,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
    )

The error is:

  File "/home/miniconda3/envs/DGL/lib/python3.8/site-packages/dgl/dataloading/dataloader.py", line 860, in __init__
    self.graph_sampler.sample, graph, self.use_uva, self.device),
AttributeError: 'PinSAGESampler' object has no attribute 'sample'

Hi, @BearBiscuit05, just inherit from dgl.dataloading.BlockSampler instead of object to solve the problem.

Hi, thank you very much for your suggestion. That error seems to be resolved, but after the modification I get a different error. As far as I can tell, the blocks themselves are built without errors, so I would still appreciate your help.

class PinSAGESampler(dgl.dataloading.BlockSampler):
    """``BlockSampler`` subclass that samples layers with random walks.

    The constructor sets only this class's own attributes; the per-layer
    work is delegated to a ``dgl.sampling.RandomWalkNeighborSampler``.
    """

    def __init__(self, g, random_walk_length, random_walk_restart_prob, num_random_walk, num_neighbor, num_layer):
        """Keep the graph and layer count and create the random-walk sampler."""
        self.g = g
        self.num_layer = num_layer
        self.sampler = dgl.sampling.RandomWalkNeighborSampler(
            g,
            random_walk_length,
            random_walk_restart_prob,
            num_random_walk,
            num_neighbor,
        )

    def sample_blocks(self, _, seed_nodes, exclude_eids=None):
        """Return the list of blocks sampled around ``seed_nodes``.

        The first positional argument and ``exclude_eids`` are accepted but
        not used. Each newly built block is placed in front of the earlier
        ones, so the block built from the original seeds ends up last.
        """
        stacked = []
        seeds = seed_nodes
        for _depth in range(self.num_layer):
            sampled = self.sampler(seeds)
            new_block = dgl.to_block(sampled, seeds)
            # Next iteration expands from this block's source nodes.
            seeds = new_block.srcdata[dgl.NID]
            stacked = [new_block, *stacked]
        return stacked

The error is:

Original Traceback (most recent call last):
  File "/home/bear/miniconda3/envs/DGL/lib/python3.8/site-packages/dgl/dataloading/dataloader.py", line 377, in _prefetcher_entry
    batch = next(dataloader_it)
  File "/home/bear/miniconda3/envs/DGL/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
    data = self._next_data()
  File "/home/bear/miniconda3/envs/DGL/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 692, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/bear/miniconda3/envs/DGL/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 40, in fetch
    return self.collate_fn(data)
  File "/home/bear/miniconda3/envs/DGL/lib/python3.8/site-packages/dgl/dataloading/dataloader.py", line 537, in __call__
    batch = self.sample_func(self.g, items)
  File "/home/bear/miniconda3/envs/DGL/lib/python3.8/site-packages/dgl/dataloading/base.py", line 244, in sample
    return self.assign_lazy_features(result)
  File "/home/bear/miniconda3/envs/DGL/lib/python3.8/site-packages/dgl/dataloading/base.py", line 235, in assign_lazy_features
    set_src_lazy_features(blocks[0], self.prefetch_node_feats)
  File "/home/bear/miniconda3/envs/DGL/lib/python3.8/site-packages/dgl/heterograph.py", line 2232, in __getitem__
    raise DGLError('Invalid key "{}". Must be one of the edge types.'.format(orig_key))
dgl._ffi.base.DGLError: Invalid key "0". Must be one of the edge types.

Hi, this is because sample_blocks is required to return a triplet (input_nodes, output_nodes, blocks). The other missing piece is the call to the parent class initializer, super().__init__(). Try this code:

class PinSAGESampler(dgl.dataloading.BlockSampler):
    """Block sampler that expands seeds layer by layer using random walks.

    Initializes the ``BlockSampler`` base class and wraps a
    ``dgl.sampling.RandomWalkNeighborSampler`` that is called once per layer.
    """

    def __init__(self, g, random_walk_length, random_walk_restart_prob, num_random_walk, num_neighbor, num_layer):
        """Initialize the base sampler, then build the random-walk sampler.

        Every argument except ``num_layer`` is passed through, in order, to
        ``dgl.sampling.RandomWalkNeighborSampler``.
        """
        super().__init__()
        self.g = g
        self.num_layer = num_layer
        self.sampler = dgl.sampling.RandomWalkNeighborSampler(
            g,
            random_walk_length,
            random_walk_restart_prob,
            num_random_walk,
            num_neighbor,
        )

    def sample_blocks(self, _, seed_nodes, exclude_eids=None):
        """Sample ``num_layer`` blocks and return ``(input_nodes, output_nodes, blocks)``.

        ``output_nodes`` is the ``seed_nodes`` argument as received.
        ``input_nodes`` is the source-node ID tensor of the block built
        last, or ``seed_nodes`` itself when ``num_layer`` is 0. ``blocks``
        lists the block built last first and the block built from the
        original seeds last. The first positional argument and
        ``exclude_eids`` are not used.
        """
        output_nodes = seed_nodes
        input_nodes = seed_nodes
        collected = []
        for _layer in range(self.num_layer):
            frontier = self.sampler(input_nodes)
            layer_block = dgl.to_block(frontier, input_nodes)
            # Source nodes of this layer seed the next, wider layer.
            input_nodes = layer_block.srcdata[dgl.NID]
            collected.append(layer_block)
        # Collected seed-side first; return in reversed order.
        return input_nodes, output_nodes, collected[::-1]

This topic was automatically closed 30 days after the last reply. New replies are no longer allowed.