Knowing the index when using minibatches of graphs

I am doing batch training as described in the tutorial,

i.e.:

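```python
import dgl
import torch
from torch.utils.data import DataLoader

def collate(samples):
    # The input `samples` is a list of (graph, label) pairs.
    graphs, labels = map(list, zip(*samples))
    # Merge the graphs into one batched graph; the nodes of each graph are
    # laid out contiguously, in the order the graphs appear in `graphs`.
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

# `trainset` and `batch_size` are defined elsewhere in the training script.
data_loader = DataLoader(trainset, batch_size=batch_size,
                         collate_fn=collate, drop_last=False, shuffle=True)
```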
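Because the loader uses `shuffle=True`, the order of graphs inside a batch does not match their order in `trainset`. A minimal sketch for carrying each graph's `trainset` position through batching is to wrap the dataset so it also yields the index, then append those indices in the collate function. `IndexedDataset` and `make_indexed_collate` are hypothetical helpers, not part of DGL or PyTorch; the sketch assumes `trainset` supports `len()` and integer indexing.

```python
import torch
from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    """Hypothetical wrapper: returns (index, sample) so the position of each
    sample in the wrapped dataset survives shuffling and batching."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        return idx, self.base[idx]


def make_indexed_collate(inner_collate):
    """Hypothetical helper: wraps an existing collate function (such as the
    `collate` above) so it also returns the dataset indices of the batch."""

    def collate_with_index(samples):
        # `samples` is a list of (index, original_sample) pairs.
        indices, inner_samples = map(list, zip(*samples))
        # Delegate to the original collate, then append the indices.
        return (*inner_collate(inner_samples), torch.tensor(indices))

    return collate_with_index
```

With `DataLoader(IndexedDataset(trainset), batch_size=batch_size, collate_fn=make_indexed_collate(collate), shuffle=True)`, each batch unpacks as `batched_graph, labels, indices`, where `indices[k]` is the position in `trainset` of the k-th graph in the batch. Wrapping the existing collate rather than rewriting it keeps the original batching logic untouched.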

I would like to know if there is a simple way to find out, for each node in a minibatch, the index in the original training set of the graph that node came from. Thanks in advance!

What do you mean by “initial train vector”? Does `batch_num_nodes` help?
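Since `dgl.batch` places the nodes of each input graph contiguously and in batch order, the per-graph node counts are enough to tell which graph every node belongs to. A minimal sketch, assuming the counts come from `batched_graph.batch_num_nodes()` (a method returning a tensor in DGL 0.5 and later; older versions exposed `batch_num_nodes` as an attribute returning a list). `node_graph_ids` and `graph_node_ranges` are hypothetical helper names.

```python
import torch


def node_graph_ids(batch_num_nodes):
    """For every node of a batched graph, return the position (0..B-1) of the
    graph it came from within the minibatch.

    `batch_num_nodes` holds the node count of each graph in the batch
    (a list or a 1-D integer tensor).
    """
    counts = torch.as_tensor(batch_num_nodes, dtype=torch.long)
    graph_ids = torch.arange(len(counts), device=counts.device)
    # Repeat graph id g once per node of graph g; this matches the layout of
    # dgl.batch, which stores each graph's nodes contiguously in batch order.
    return torch.repeat_interleave(graph_ids, counts)


def graph_node_ranges(batch_num_nodes):
    """Return a [start, end) node-ID range for each graph in the batch."""
    counts = torch.as_tensor(batch_num_nodes, dtype=torch.long)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    return list(zip(starts.tolist(), ends.tolist()))
```

For a batch of graphs with 2, 3 and 1 nodes, `node_graph_ids([2, 3, 1])` gives `tensor([0, 0, 1, 1, 1, 2])` and `graph_node_ranges([2, 3, 1])` gives `[(0, 2), (2, 5), (5, 6)]`. If the batch also carries a hypothetical tensor `indices` holding the `trainset` position of each graph, `indices[node_graph_ids(batched_graph.batch_num_nodes())]` maps every node to that position.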

Thanks! This is exactly what I needed: it tells me which nodes of the batch belong to which input graph.
