What is the maximum number of nodes in a batch produced by
`dgl.dataloading.MultiLayerNeighborSampler`? Consider the following example, which contextualizes the question:
```python
import dgl
import numpy as np

dataset = dgl.data.RedditDataset()
graph = dataset[0]  # RedditDataset holds a single graph
adj_sparse = graph.adj(scipy_fmt='coo')
train_ids = np.arange(adj_sparse.shape[0])[graph.ndata['train_mask']]
graph = dgl.graph((adj_sparse.row, adj_sparse.col))
dataset = None

fanouts = [5, 5]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
dataloader = dgl.dataloading.DataLoader(
    graph, train_ids, sampler,
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=0)

dataloader_iter = iter(dataloader)
input_nodes, output_nodes, mfgs = next(dataloader_iter)
print(len(input_nodes))  # 2122
```
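To see which layer contributes how many nodes, one can inspect the message-flow graphs (MFGs) directly. Here is a minimal sketch using the `mfgs` from the snippet above; `num_src_nodes()` and `num_dst_nodes()` are the standard DGL block methods:

```python
# mfgs[0] is the outermost block (two-hop layer), mfgs[-1] the innermost.
for layer, mfg in enumerate(mfgs):
    print(f"block {layer}: {mfg.num_src_nodes()} source nodes -> "
          f"{mfg.num_dst_nodes()} destination nodes")

# Sanity checks: the first block's sources are exactly input_nodes,
# and the last block's destinations are the 64 seed (output) nodes.
assert mfgs[0].num_src_nodes() == len(input_nodes)
assert mfgs[-1].num_dst_nodes() == len(output_nodes)
```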
I expect that the number of nodes we make predictions for is 64 (the `batch_size` parameter), that the total number of one-hop neighbours is $\leq 64 \times 5$ (the `fanouts` setting), and that the total number of two-hop neighbours is $\leq 5 \times (64 \times 5)$. Altogether, the total number of nodes passed into the GNN should be $\leq 64 + 64 \times 5 + 64 \times 5 \times 5 = 1{,}984$. However, I obtain a batch with 2,122 nodes, exceeding what I thought was the upper bound. Could someone tell me what the correct upper bound is as a function of the batch size and the `fanouts` array (and any other relevant parameters that I may have failed to consider)?
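For reference, here is the arithmetic above written as a small helper (`expected_upper_bound` is my own hypothetical function, not part of DGL), computing the bound I expected as a function of the batch size and fanouts:

```python
def expected_upper_bound(batch_size, fanouts):
    # Seed nodes plus at most fanout-many neighbours per node at each hop:
    # b + b*f1 + b*f1*f2 + ...
    total = batch_size
    frontier = batch_size
    for f in fanouts:
        frontier *= f
        total += frontier
    return total

print(expected_upper_bound(64, [5, 5]))  # 1984, yet len(input_nodes) is 2122
```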
Thanks in advance!