Question about batch() and unbatch()

I am doing graph-level classification and use dgl.batch() in the collate_fn() of my DataLoader to combine a chunk of graphs into one batch. As far as I know, batch() creates one big graph whose components (the original graphs) are disjoint, so the computation is performed on each graph separately but in parallel. However, when I run the classification task, I get 1 output instead of batch_size outputs. One possible reason is that the computation is performed on the big graph created by batch() instead of on the disjoint graphs. Do I have to call unbatch() in the forward pass for this kind of task? I followed this tutorial for my batching implementation. Thank you!
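To make the symptom concrete, here is a minimal pure-Python sketch of what batching does conceptually. It is not DGL's implementation: the helper `batch_graphs` and the toy feature values are hypothetical. Each graph's node IDs are shifted by an offset so the graphs become disjoint components of one merged graph, and the per-graph node counts are recorded (DGL exposes these as `bg.batch_num_nodes()`). A readout that averages over all nodes of the merged graph collapses the whole batch into a single value.

```python
# Hypothetical sketch of batching as a disjoint union (not DGL's actual code).

def batch_graphs(graphs):
    """graphs: list of (num_nodes, edge_list). Returns (total_nodes, edges, batch_num_nodes)."""
    edges = []
    offset = 0
    batch_num_nodes = []
    for num_nodes, edge_list in graphs:
        # Shift node IDs so each graph occupies its own contiguous ID range.
        edges.extend((u + offset, v + offset) for u, v in edge_list)
        batch_num_nodes.append(num_nodes)
        offset += num_nodes
    return offset, edges, batch_num_nodes

g1 = (3, [(0, 1), (1, 2)])
g2 = (2, [(0, 1)])
total, edges, counts = batch_graphs([g1, g2])
# total == 5, edges == [(0, 1), (1, 2), (3, 4)], counts == [3, 2]

# Toy node features for the merged graph (first 3 nodes from g1, last 2 from g2).
node_feats = [1.0, 2.0, 3.0, 10.0, 20.0]

# Pooling over *all* nodes ignores graph boundaries and yields one number for the
# whole batch: the "1 output instead of batch_size outputs" symptom.
naive = sum(node_feats) / len(node_feats)
```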

I have solved it. The cause was my custom readout layer: it produced a single output for the whole batched graph, so I had to use unbatch() in the readout layer to get one output per graph.
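A per-graph readout only needs to know which nodes belong to which graph. The sketch below illustrates the idea in plain Python, assuming node features are stored in batch order together with per-graph node counts. The function `per_graph_mean` is hypothetical, not a DGL API. It splits the features by graph (the effect unbatch() gives you) and pools each segment separately. In DGL itself, built-in readouts such as `dgl.mean_nodes(bg, 'h')` or `dgl.sum_nodes(bg, 'h')` already pool per graph on a batched graph and return one row per graph, so an explicit unbatch() may not be necessary for standard mean or sum pooling.

```python
# Hypothetical sketch of a per-graph (segment) mean readout.

def per_graph_mean(node_feats, batch_num_nodes):
    """Mean-pool node features separately for each graph in the batch."""
    outputs = []
    start = 0
    for n in batch_num_nodes:
        segment = node_feats[start:start + n]  # nodes belonging to one graph
        # Guard against empty graphs to avoid division by zero.
        outputs.append(sum(segment) / n if n > 0 else 0.0)
        start += n
    return outputs

node_feats = [1.0, 2.0, 3.0, 10.0, 20.0]  # graph 1: 3 nodes, graph 2: 2 nodes
print(per_graph_mean(node_feats, [3, 2]))  # [2.0, 15.0]
```

The result has one entry per graph, i.e. batch_size outputs, which is what the classification head expects.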
