How to handle a batch of graphs with similar structure

I am trying to use dgl.nn.HeteroGraphConv as a feature extractor in reinforcement learning.
When I sample from the replay buffer, it returns a batch of heterographs.
Each graph has 9000 nodes and 26000 edges, and each node has a 1900 x 1 feature vector.

All of these graphs are derived from one source graph by adding or deleting a small fraction of its edges (less than 5%), and all of them share the same node features.
I tried dgl.batch(), but it is too slow.
I then tried processing the graphs one at a time in a for loop, but that does not make efficient use of the GPU.

What should I do? Should I reduce the batch size and keep using dgl.batch()?

Have you profiled your code to see where the bottleneck is? It may not be in dgl.batch() itself but somewhere else in the pipeline. pyinstrument and line_profiler are two common choices for profiling.
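Before reaching for a full profiler, a quick first pass is to time each stage of one training step separately. The sketch below uses a hypothetical helper, StageTimer (not part of DGL, pyinstrument or line_profiler); the stage names and time.sleep calls are placeholders for replay-buffer sampling, dgl.batch() and the HeteroGraphConv forward pass. Because CUDA kernels run asynchronously, the timer accepts an optional sync callable such as torch.cuda.synchronize; without it, GPU time tends to be attributed to whichever later stage happens to block.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

class StageTimer:
    """Hypothetical helper: accumulates wall-clock time per named stage."""

    def __init__(self, sync=None):
        # Optional callable run before reading the clock, e.g.
        # torch.cuda.synchronize, so pending GPU work is counted correctly.
        self.sync = sync
        self.totals = defaultdict(float)

    @contextmanager
    def stage(self, name):
        if self.sync is not None:
            self.sync()
        start = time.perf_counter()
        try:
            yield
        finally:
            if self.sync is not None:
                self.sync()
            self.totals[name] += time.perf_counter() - start

    def report(self):
        total = sum(self.totals.values()) or 1.0
        # Print stages from slowest to fastest with their share of the total.
        for name, secs in sorted(self.totals.items(), key=lambda kv: -kv[1]):
            print(f"{name:>10s}: {secs:8.4f} s ({100 * secs / total:5.1f}%)")

timer = StageTimer()  # pass sync=torch.cuda.synchronize when timing GPU work
for _ in range(3):
    with timer.stage("sample"):
        time.sleep(0.01)  # placeholder for sampling from the replay buffer
    with timer.stage("batch"):
        time.sleep(0.03)  # placeholder for dgl.batch(graphs)
    with timer.stage("forward"):
        time.sleep(0.02)  # placeholder for the model's forward pass
timer.report()
```

For a finer-grained view, `pyinstrument train.py` prints a call tree, and `kernprof -l -v train.py` (line_profiler's runner, with the functions of interest decorated with `@profile`) prints per-line timings; train.py stands in for your own script.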

Once you have found the bottleneck, we can work on optimizing it.