Access batch data by index

Hi, when I create a batched graph, is there a way to access the individual graphs in the batch by index?

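import dgl
import torch as th

# g1: 4 nodes, 3 edges (0->1, 1->2, 2->3)
g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
# g2: 3 nodes, 4 edges (0->0, 0->1, 0->2, 1->0)
g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
bg = dgl.batch([g1, g2])
print(bg)
print(bg.batch_size)         # 2
print(bg.batch_num_nodes())  # tensor([4, 3])
print(bg.batch_num_edges())  # tensor([3, 4])
print(bg.edges())            # g2's node IDs are shifted by 4 inside bg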

Something like this (hypothetical API):

subgraph_0 = bg.batch[0]
subgraph_1 = bg.batch[1]
print(subgraph_0 == g1)  # desired output: True
print(subgraph_1 == g2)  # desired output: True

thanks!

We recently added a slice_batch API (dgl.slice_batch), which does exactly what you described. To use it, either install the nightly build (if you are on Linux) or install DGL from source.
