I found a code before DGL v0.8, using a custom data sampler and dgl.dataloading.NodeDataLoader() for subgraph sampling, but the NodeDataLoader function has been deprecated. Then how should I modify my data sampler to meet the use of dgl.dataloading.DataLoader.
class PinSAGESampler(object):
def __init__(self, g, random_walk_length, random_walk_restart_prob, num_random_walk, num_neighbor, num_layer):
self.g = g
self.num_layer = num_layer
self.sampler = dgl.sampling.RandomWalkNeighborSampler(
g, random_walk_length, random_walk_restart_prob, num_random_walk, num_neighbor)
def sample_blocks(self, _, seed_nodes):
blocks = []
for _ in range(self.num_layer):
frontier = self.sampler(seed_nodes)
block = dgl.to_block(frontier, seed_nodes)
seed_nodes = block.srcdata[dgl.NID]
blocks.insert(0, block)
return blocks
...
...
sampler = PinSAGESampler(g, 3, 0.5,4,5, 3)
dataloader = dgl.dataloading.NodeDataLoader(
g,
train_nids,
sampler,
use_ddp=True,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0)
error is :
File "/home/miniconda3/envs/DGL/lib/python3.8/site-packages/dgl/dataloading/dataloader.py", line 860, in __init__
self.graph_sampler.sample, graph, self.use_uva, self.device),
AttributeError: 'PinSAGESampler' object has no attribute 'sample'