I am doing batch training using the information provided in the tutorial, i.e.:
def collate(samples):
    # The input `samples` is a list of pairs (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

data_loader = DataLoader(trainset, batch_size=batch_size,
                         collate_fn=helper_functions.collate,
                         drop_last=False, shuffle=True)
I would like to know if there is a simple way to find, for each node in a mini-batch, the index in the initial training set that it came from. Thanks in advance!
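
One possible approach, sketched below under some assumptions: wrap the training set so that every sample also returns its own position, and have the collate function pass those positions along. The names `IndexedDataset`, `collate_with_index`, `node_to_sample_index`, `batch_fn` and `num_nodes_fn` are hypothetical helpers, not part of DGL or PyTorch. The sketch assumes that `trainset[i]` returns a `(graph, label)` pair and that the batching function keeps the nodes of each graph together, in the order the graphs are given (which is how `dgl.batch` lays out nodes). The batching and node-count functions are passed in as arguments so the logic can be checked without DGL installed.

```python
class IndexedDataset:
    """Hypothetical wrapper: yields (graph, label, idx) instead of (graph, label)."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        graph, label = self.dataset[idx]
        return graph, label, idx


def node_to_sample_index(sample_indices, nodes_per_graph):
    """Repeat each sample index once per node of its graph, in batch order."""
    node_index = []
    for idx, num_nodes in zip(sample_indices, nodes_per_graph):
        node_index.extend([idx] * num_nodes)
    return node_index


def collate_with_index(samples, batch_fn, num_nodes_fn):
    """Collate (graph, label, idx) triples and keep track of where they came from.

    batch_fn     -- merges a list of graphs into one batched graph
    num_nodes_fn -- returns the number of nodes of a single graph
    """
    graphs, labels, indices = map(list, zip(*samples))
    batched_graph = batch_fn(graphs)
    nodes_per_graph = [num_nodes_fn(g) for g in graphs]
    node_index = node_to_sample_index(indices, nodes_per_graph)
    # labels is left as a list here; wrap it in torch.tensor(labels) as in the
    # original collate function if a tensor is needed.
    return batched_graph, labels, indices, node_index


# Assumed wiring with DGL and PyTorch (not executed in this sketch):
# loader = DataLoader(
#     IndexedDataset(trainset), batch_size=batch_size, shuffle=True,
#     collate_fn=lambda s: collate_with_index(
#         s, dgl.batch, lambda g: g.number_of_nodes()))
```

With this wiring, each batch would carry `indices` (one entry per graph, giving its position in `trainset`) and `node_index` (one entry per node of the batched graph, giving the position of the graph that node belongs to), so shuffling in the `DataLoader` no longer hides where a sample came from.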