Batch index for batch graph nodes

How can I get the index of each node in a batched graph? For example, given a batched graph made of two graphs with two nodes each, the batch index tensor would be [0,0,1,1], which indicates that the first two nodes belong to the first graph in the batch while the last two nodes belong to the second.

You can get the number of nodes of each graph in a batch with batched_graph.batch_num_nodes(). The index can then be obtained with a torch.repeat_interleave call, which repeats each graph's position in the batch as many times as that graph has nodes.

