How does batching for graph classification work in detail?


So I would like to know how batching for graph classification works.
I understand what is achieved by block-diagonalizing the batch, but I don't understand how the weight matrix is applied to that block matrix.

For example, I have A (an adjacency matrix of size [n x n]) and a feature matrix X of size [n x p], where p is the number of features. So my weight matrix W has to be of size [p x f], where f is an arbitrary output size.

In that sense, A x X x W gives me an [n x f] matrix.
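A minimal shape check of the single-graph product, assuming NumPy and made-up sizes (n = 4, p = 3, f = 2) with random values; nothing here comes from a specific library's layer implementation:

```python
import numpy as np

# Illustrative sizes (assumed): n nodes, p input features, f output features
n, p, f = 4, 3, 2
rng = np.random.default_rng(0)

A = (rng.random((n, n)) < 0.5).astype(float)  # adjacency matrix, [n x n]
X = rng.random((n, p))                         # node features, [n x p]
W = rng.random((p, f))                         # weight matrix, [p x f]

H = A @ X @ W                                  # [n x n][n x p][p x f] -> [n x f]
print(H.shape)  # (4, 2)
```

Matrix multiplication is associative, so A @ (X @ W) gives the same result; computing X @ W first is cheaper when f < p.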

So now I want to batch multiple graphs. I combine the adjacency matrices A of the graphs in my batch into one sparse block-diagonal matrix and stack their feature matrices X on top of each other. I can still multiply these two matrices. However, how is the weight matrix then applied to that sparse matrix?
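The weight matrix never interacts with the block structure. A_batch @ X_batch is already a dense [N x p] matrix (N = total number of nodes) in which each graph's rows sit in their own contiguous slice, and multiplying by the shared W maps every row from p to f features independently. The sketch below assumes NumPy, three graphs of made-up sizes, and a small hypothetical helper `block_diag` written here (scipy.sparse.block_diag does the same thing with a sparse result):

```python
import numpy as np

def block_diag(mats):
    """Hypothetical helper: place square matrices on the diagonal of one zero matrix."""
    total = sum(m.shape[0] for m in mats)
    out = np.zeros((total, total))
    offset = 0
    for m in mats:
        k = m.shape[0]
        out[offset:offset + k, offset:offset + k] = m
        offset += k
    return out

rng = np.random.default_rng(0)
p, f = 3, 2
sizes = [2, 3, 4]  # nodes per graph (assumed; graphs may differ in size)
As = [(rng.random((k, k)) < 0.5).astype(float) for k in sizes]
Xs = [rng.random((k, p)) for k in sizes]
W = rng.random((p, f))  # one weight matrix shared by all graphs

A_batch = block_diag(As)          # [N x N], N = sum(sizes)
X_batch = np.vstack(Xs)           # [N x p]: features are stacked, not block-diagonal
H_batch = A_batch @ X_batch @ W   # [N x f]

# Same result as running each graph on its own and stacking the outputs
H_separate = np.vstack([A @ X @ W for A, X in zip(As, Xs)])
print(np.allclose(H_batch, H_separate))  # True
```

Because the off-diagonal blocks of A_batch are zero, no information leaks between graphs during message passing.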

For example, let's say I have a batch size of three and all three graphs have the same number of nodes n.

Then the new A is of size [3n x 3n] and X is of size [3n x p]. Multiplying these two matrices gives a matrix of size [3n x p], and multiplying that by W [p x f] gives a [3n x f] matrix. How do I then get three predictions (one for each graph) from that output matrix?
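With equally sized graphs, rows 0..n-1 of the [3n x f] output belong to graph 0, rows n..2n-1 to graph 1, and so on, so a reshape separates them. A readout (here a mean over nodes, chosen for illustration) turns each graph's [n x f] block into one row, and a final linear layer gives one prediction per graph. This is a NumPy sketch with assumed sizes; `W_out` is a hypothetical classifier weight:

```python
import numpy as np

rng = np.random.default_rng(0)
B, n, f, c = 3, 4, 2, 5           # batch size, nodes per graph, hidden size, classes (assumed)
H = rng.random((B * n, f))        # stands in for the [3n x f] output of A @ X @ W

per_graph = H.reshape(B, n, f)    # [3 x n x f]: row-major reshape keeps each graph's rows together
graph_repr = per_graph.mean(axis=1)  # mean readout -> [3 x f], one row per graph

W_out = rng.random((f, c))        # hypothetical classifier weights
logits = graph_repr @ W_out       # [3 x c]: one prediction per graph
print(logits.shape)  # (3, 5)
```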

I'm not sure I understand the problem. For A x X x W, you will get [n x f] rather than [p x f]. For [3n x 3n] x [3n x p] x [p x f], you will get [3n x f]. You can record the graph membership of each node and then unbatch the output to get the node features for each graph.
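Recording membership and unbatching can be sketched in NumPy as follows, assuming graphs of different (made-up) sizes and a stand-in output matrix. The membership vector stores the graph id of every node, and the split points are the cumulative node counts:

```python
import numpy as np

sizes = [2, 3, 4]  # number of nodes in each graph of the batch (assumed)
f = 2
H = np.arange(sum(sizes) * f, dtype=float).reshape(-1, f)  # stand-in [N x f] output

# Graph membership of every node: [0, 0, 1, 1, 1, 2, 2, 2, 2]
membership = np.repeat(np.arange(len(sizes)), sizes)

# Unbatch: split rows at the cumulative node counts
per_graph = np.split(H, np.cumsum(sizes)[:-1])
print([h.shape for h in per_graph])  # [(2, 2), (3, 2), (4, 2)]

# A boolean mask on the membership vector selects the same rows
graph1 = H[membership == 1]
```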

Sorry, I made a mistake in the dimensions of A x X x W; I changed it to [n x f]. But if I just unbatch, will PyTorch be able to backpropagate the error? Somewhere I read that you could use a pooling matrix to separate the graphs within a batch.

Edit: Maybe I am wrong, but if you use a pooling matrix of size [3n x 3], wouldn't you get the sum of the features per graph in the batch? I am not sure why it is not in the package.
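Such a pooling matrix does give per-graph sums. If P is [N x B] (here [3n x 3]) with P[i, g] = 1 when node i belongs to graph g and 0 otherwise, then P^T H is [B x f] and row g is the sum of graph g's node representations; dividing by the node counts gives a mean. Since it is an ordinary matrix product, gradients flow through it. A NumPy sketch with assumed sizes (a dense P is wasteful for large batches; sparse matrices or segment reductions avoid that):

```python
import numpy as np

sizes = [2, 3, 4]                 # nodes per graph (assumed)
B, N, f = len(sizes), sum(sizes), 2
rng = np.random.default_rng(0)
H = rng.random((N, f))            # node representations, [N x f]

membership = np.repeat(np.arange(B), sizes)
P = np.zeros((N, B))
P[np.arange(N), membership] = 1.0  # one-hot pooling matrix, [N x B]

graph_sums = P.T @ H               # [B x N][N x f] -> [B x f]
graph_means = graph_sums / np.array(sizes)[:, None]  # divide by node counts for a mean
print(graph_sums.shape)  # (3, 2)
```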

For node and edge features, batching just means concatenation, and PyTorch will be able to backpropagate the error.
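Concatenation, block-diagonal stacking, and splitting only rearrange rows, so the gradient of the loss with respect to the shared W is the same whether graphs are processed one at a time or as one batch; PyTorch's autograd differentiates through torch.cat and slicing in the same way. The sketch below checks that equivalence in NumPy with a hand-derived gradient. It illustrates the math rather than PyTorch itself, and the squared-error loss on sum-pooled outputs, the targets `Y`, and the helper `block_diag` are all assumptions made for the example:

```python
import numpy as np

def block_diag(mats):
    """Hypothetical helper: place square matrices on the diagonal of one zero matrix."""
    total = sum(m.shape[0] for m in mats)
    out = np.zeros((total, total))
    o = 0
    for m in mats:
        k = m.shape[0]
        out[o:o + k, o:o + k] = m
        o += k
    return out

rng = np.random.default_rng(1)
sizes, p, f = [2, 3, 4], 3, 2
As = [(rng.random((k, k)) < 0.5).astype(float) for k in sizes]
Xs = [rng.random((k, p)) for k in sizes]
Y = rng.random((len(sizes), f))   # hypothetical per-graph targets
W = rng.random((p, f))

# Assumed loss: sum_g ||sum-pool(A_g X_g W) - y_g||^2
# Per graph: u_g = column sums of A_g X_g, and dL_g/dW = 2 * outer(u_g, u_g W - y_g)
grad_separate = sum(
    2 * np.outer((A @ X).sum(axis=0), (A @ X).sum(axis=0) @ W - y)
    for A, X, y in zip(As, Xs, Y)
)

# Batched: same loss on the block-diagonal A, stacked X, and a one-hot pooling matrix
membership = np.repeat(np.arange(len(sizes)), sizes)
P = np.zeros((sum(sizes), len(sizes)))
P[np.arange(sum(sizes)), membership] = 1.0
U = P.T @ block_diag(As) @ np.vstack(Xs)   # [B x p], row g equals u_g
grad_batched = 2 * U.T @ (U @ W - Y)        # [p x f]

print(np.allclose(grad_batched, grad_separate))  # True
```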

By pooling, I assume you mean computing graph-level representations from node and edge representations. We've implemented some common methods for this purpose; see the Graph Readout section.
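To illustrate what a readout computes (this does not use DGL; the library's own functions are the ones documented in the Graph Readout section), here is a NumPy sketch of sum, mean, and max readouts as segment reductions over the node rows of each graph. It assumes made-up graph sizes, a stand-in feature matrix, and that every graph has at least one node:

```python
import numpy as np

sizes = [2, 3, 4]                               # nodes per graph (assumed, all non-empty)
H = np.arange(18, dtype=float).reshape(9, 2)    # stand-in node representations, [N x f]

# Index of the first row of each graph: [0, 2, 5]
starts = np.concatenate(([0], np.cumsum(sizes)[:-1]))

sum_readout = np.add.reduceat(H, starts, axis=0)       # [B x f]
mean_readout = sum_readout / np.array(sizes)[:, None]  # [B x f]
max_readout = np.maximum.reduceat(H, starts, axis=0)   # [B x f]
print(max_readout)
```

Each readout yields one row per graph, which can then be fed to a classifier as in the three-graph case above.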