GAT - Batched Graph Classification Task

I am trying to rework the GAT example (https://github.com/dmlc/dgl/tree/master/examples/pytorch/gat) for graph classification rather than the node labeling tasks in that folder. I was wondering whether I could stop storing the graph g in all the modules and instead pass it in only as an argument to the forward methods of the various submodules. Would there be any other changes I need to make?

Yeah, you can simply pass a DGLGraph object to forward. This is also how I do these things :slight_smile:
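To illustrate the pattern of keeping only parameters on the module and passing the graph to forward, here is a minimal pure-Python sketch of a single-head GAT-style layer. It is not DGL code: the graph is a hypothetical `(num_nodes, src, dst)` tuple and `ToyGATLayer` is an illustrative name. It follows the usual GAT steps (project, score each edge with LeakyReLU(el[u] + er[v]), softmax over each node's incoming edges, weighted sum) and leaves out dropout, multiple heads and residuals.

```python
import math


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


class ToyGATLayer:
    """Hypothetical single-head GAT-style layer that stores only parameters.

    The graph is NOT kept on the module; it is passed to forward(), so the
    same layer can be applied to a different graph (or batch) on every call.
    """

    def __init__(self, weight, attn_l, attn_r):
        self.weight = weight  # out_dim x in_dim matrix (list of rows)
        self.attn_l = attn_l  # attention vector applied to source nodes
        self.attn_r = attn_r  # attention vector applied to destination nodes

    def forward(self, graph, feats):
        num_nodes, src, dst = graph
        # Linear projection: ft[v] = W h[v]
        ft = [[sum(w * x for w, x in zip(row, h)) for row in self.weight] for h in feats]
        el = [sum(a * f for a, f in zip(self.attn_l, v)) for v in ft]
        er = [sum(a * f for a, f in zip(self.attn_r, v)) for v in ft]
        # Unnormalized edge scores: e_uv = LeakyReLU(el[u] + er[v])
        scores = [leaky_relu(el[u] + er[v]) for u, v in zip(src, dst)]
        # Numerically stable softmax over the incoming edges of each node
        max_in = [-math.inf] * num_nodes
        for v, s in zip(dst, scores):
            max_in[v] = max(max_in[v], s)
        exps = [math.exp(s - max_in[v]) for v, s in zip(dst, scores)]
        denom = [0.0] * num_nodes
        for v, e in zip(dst, exps):
            denom[v] += e
        # Attention-weighted sum of projected source features
        out_dim = len(self.weight)
        out = [[0.0] * out_dim for _ in range(num_nodes)]
        for u, v, e in zip(src, dst, exps):
            alpha = e / denom[v]
            for k in range(out_dim):
                out[v][k] += alpha * ft[u][k]
        return out


layer = ToyGATLayer(weight=[[1.0, 0.0], [0.0, 1.0]], attn_l=[0.5, 0.5], attn_r=[0.5, -0.5])
# The same layer object is applied to two different graphs.
h1 = layer.forward((2, [0, 1], [1, 0]), [[1.0, 2.0], [3.0, 4.0]])
h2 = layer.forward((3, [0, 1, 2], [1, 2, 0]), [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
```

Because nothing graph-specific is stored on the layer, each call can receive a fresh graph or a new batch.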

The only thing to be careful about is that forward stores a lot of node and edge data in the DGLGraph object. That is not an issue if you simply load new batches of graphs every time and only use them during forward. But if you want to be safe, you may want to remove the data stored under a key such as 'hv' with g.ndata.pop('hv') or g.edata.pop('hv') once forward is finished.
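One way to decide what to pop is to pop whatever forward added: record the keys present before the call and remove every new key afterwards, inside a `finally` so cleanup also happens if forward raises. The sketch below uses a hypothetical `ToyGraph` whose `ndata`/`edata` are plain dicts so it runs without DGL; `run_and_clean` and `toy_forward` are illustrative names, and it assumes the same `keys()`/`pop()` calls work on a real graph's `g.ndata`/`g.edata`, which you should verify for your DGL version. Newer DGL releases also provide `g.local_scope()`, a context manager that discards features written inside it.

```python
class ToyGraph:
    """Hypothetical stand-in for a DGLGraph: only the ndata/edata stores."""

    def __init__(self):
        self.ndata = {}
        self.edata = {}


def run_and_clean(forward, g, *args):
    """Call forward(g, *args), then pop every key it added to g.ndata/g.edata."""
    node_keys = set(g.ndata.keys())
    edge_keys = set(g.edata.keys())
    try:
        return forward(g, *args)
    finally:
        # Remove only the keys created during this forward call
        for k in list(g.ndata.keys()):
            if k not in node_keys:
                g.ndata.pop(k)
        for k in list(g.edata.keys()):
            if k not in edge_keys:
                g.edata.pop(k)


def toy_forward(g, feats):
    # Intermediates written onto the graph, like the GAT example does
    # (key names and placement here are hypothetical)
    g.ndata['ft'] = [2 * x for x in feats]
    g.edata['a_drop'] = [0.5, 0.5]
    return sum(g.ndata['ft'])


g = ToyGraph()
g.ndata['h'] = [1.0, 2.0]  # input features set before forward are kept
result = run_and_clean(toy_forward, g, g.ndata['h'])
```

With this approach you do not need to list "ft", "a_drop", "z" and so on by hand, and input features that existed before the call are left alone.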

Quick follow-up on that: will the code as-is work with the output of dgl.batch()?

Regarding the pop commands, what exactly would I be popping? I see many variables (“ft”, “a_drop”, “z”, etc.). And would I add these pop lines at the point where forward returns, whether or not the residual is used?

Maruthi

  1. dgl.batch returns a BatchedDGLGraph object. Apart from being read-only, it should behave the same as a normal DGLGraph object.

  2. Correct.
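Conceptually, batching merges the graphs into one large graph whose components are disconnected, which is why node- and edge-wise GAT code runs unchanged on it. The pure-Python sketch below (hypothetical helper names `batch_graphs` and `mean_readout`, not DGL's implementation) shows the node-ID offsetting, plus a per-graph mean readout of the kind graph classification needs on top of the GAT layers; DGL provides readout helpers such as `dgl.mean_nodes` for this. It assumes every graph has at least one node.

```python
def batch_graphs(graphs):
    """Merge (num_nodes, src, dst) graphs into one disconnected graph.

    Node IDs of graph i are shifted by the total node count of graphs 0..i-1,
    so no edge ever connects two different input graphs.
    """
    total, src, dst, batch_num_nodes = 0, [], [], []
    for n, s, d in graphs:
        src.extend(u + total for u in s)
        dst.extend(v + total for v in d)
        batch_num_nodes.append(n)
        total += n
    return (total, src, dst), batch_num_nodes


def mean_readout(node_feats, batch_num_nodes):
    """Average node features per original graph (one vector per graph)."""
    out, start = [], 0
    for n in batch_num_nodes:
        rows = node_feats[start:start + n]
        out.append([sum(col) / n for col in zip(*rows)])
        start += n
    return out


bg, sizes = batch_graphs([(2, [0, 1], [1, 0]), (3, [0, 1, 2], [1, 2, 0])])
graph_feats = mean_readout(
    [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [3.0, 3.0], [6.0, 0.0]], sizes
)
```

Since edges never cross between components, a softmax over each node's incoming edges only ever mixes nodes from the same original graph; the readout then turns node features into one vector per graph for the classifier.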