What is the maximum number of nodes in a batch produced by
`dgl.dataloading.MultiLayerNeighborSampler`? Consider the following example, which contextualizes the question:
```python
import dgl
import numpy as np

dataset = dgl.data.RedditDataset()
graph = dataset[0]  # RedditDataset holds a single graph
adj_sparse = graph.adj(scipy_fmt='coo')
train_ids = np.arange(adj_sparse.shape[0])[graph.ndata['train_mask']]
graph = dgl.graph((adj_sparse.row, adj_sparse.col))
dataset = None

fanouts = [5, 5]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
dataloader = dgl.dataloading.DataLoader(
    graph, train_ids, sampler,
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=0)

dataloader_iter = iter(dataloader)
input_nodes, output_nodes, mfgs = next(dataloader_iter)
print(len(input_nodes))  # 2122
```
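To see which layer contributes how many nodes, one can inspect the message-flow graphs (MFGs) directly. Here is a minimal sketch using the `mfgs` from the snippet above; `num_src_nodes()` and `num_dst_nodes()` are the standard DGL block methods:

```python
# mfgs[0] is the outermost block (two-hop layer), mfgs[-1] the innermost.
for layer, mfg in enumerate(mfgs):
    print(f"block {layer}: {mfg.num_src_nodes()} source nodes -> "
          f"{mfg.num_dst_nodes()} destination nodes")

# Sanity checks: the first block's sources are exactly input_nodes,
# and the last block's destinations are the 64 seed (output) nodes.
assert mfgs[0].num_src_nodes() == len(input_nodes)
assert mfgs[-1].num_dst_nodes() == len(output_nodes)
```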
I expect that the number of nodes we make predictions for is 64 (the `batch_size` parameter), that the total number of one-hop neighbours is $\leq 64 \times 5$ (the `fanouts` setting), and that the total number of two-hop neighbours is $\leq 5 \times (64 \times 5)$. Altogether, the total number of nodes passed into the GNN should be $\leq 64 + 64 \times 5 + 64 \times 5 \times 5 = 1{,}984$. However, I obtain a batch with 2,122 nodes, exceeding what I thought was the upper bound. Could someone tell me what the correct upper bound is as a function of the batch size and the `fanouts` array (and any other relevant parameters that I may have failed to consider)?
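For reference, here is the arithmetic above written as a small helper (`expected_upper_bound` is my own hypothetical function, not part of DGL), computing the bound I expected as a function of the batch size and fanouts:

```python
def expected_upper_bound(batch_size, fanouts):
    # Seed nodes plus at most fanout-many neighbours per node at each hop:
    # b + b*f1 + b*f1*f2 + ...
    total = batch_size
    frontier = batch_size
    for f in fanouts:
        frontier *= f
        total += frontier
    return total

print(expected_upper_bound(64, [5, 5]))  # 1984, yet len(input_nodes) is 2122
```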
Thanks in advance!