How to classify entire graphs in batches?



I would like to use DGL to classify graphs that represent the structure of diagrams in science textbooks available in the AI2D Dataset.

Note that each diagram belongs to one of 17 categories, and I would like to classify entire diagrams.

I’m new to graph convolutional networks, but so far I have managed to create a PyTorch DataLoader that provides me with (1) a BatchedDGLGraph object containing the graph of each diagram in the batch and (2) a list of integers representing the class of each diagram in the batch.

For features, each node contains the coordinates of its bounding box in the diagram.
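For readers wondering what such a loader does under the hood, here is a minimal plain-PyTorch sketch of a collate function. In DGL this merging step is what `dgl.batch` does. Here each diagram is instead represented by a hypothetical tuple `(node_features, edge_index, label)` rather than a DGLGraph, and the graph sizes and labels are made up for illustration:

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical stand-in for a DGLGraph (NOT the DGL API): a diagram is a tuple
# (node_features [N, 4], edge_index [2, E], label).
def make_diagram(num_nodes, label):
    feats = torch.rand(num_nodes, 4)          # e.g. bounding box (x1, y1, x2, y2)
    src = torch.arange(num_nodes - 1)
    edge_index = torch.stack([src, src + 1])  # a simple chain of nodes
    return feats, edge_index, label

def collate(samples):
    """Merge a list of diagrams into one disjoint-union graph plus labels."""
    feats, edges, graph_ids, labels = [], [], [], []
    offset = 0
    for i, (x, edge_index, label) in enumerate(samples):
        feats.append(x)
        edges.append(edge_index + offset)  # shift node ids into the batch range
        graph_ids.append(torch.full((x.shape[0],), i, dtype=torch.long))
        labels.append(label)
        offset += x.shape[0]
    return (torch.cat(feats), torch.cat(edges, dim=1),
            torch.cat(graph_ids), torch.tensor(labels))

dataset = [make_diagram(n, n % 17) for n in range(3, 11)]
loader = DataLoader(dataset, batch_size=4, collate_fn=collate)
x, edge_index, graph_ids, labels = next(iter(loader))
```

`graph_ids` records which diagram each node came from. Per-graph readouts later rely on this information.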

How should I modify the vanilla GCN described in the tutorial to classify entire diagrams and process them in batches?

Thanks in advance for the help, and many thanks for a great library! If I get this to work, perhaps it could make an interesting tutorial for the documentation?


A good reference is the paper "Relational inductive biases, deep learning, and graph networks". Basically, to classify a whole graph, you need to summarize the node and edge hidden states (e.g., by averaging or max pooling) to yield the global state of the graph.
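As a concrete illustration of "summarizing" hidden states, here is a minimal sketch for a single graph. Concatenating mean- and max-pooled node states with mean-pooled edge states is an arbitrary choice made for this example. It is not the specific aggregation used in the paper, and learned aggregators are also common:

```python
import torch

def global_state(node_h, edge_h):
    """Summarize the node and edge hidden states of ONE graph into a vector.

    node_h: [N, D] node hidden states; edge_h: [E, D_e] edge hidden states.
    Returns the concatenation [mean(nodes), max(nodes), mean(edges)].
    """
    node_mean = node_h.mean(dim=0)
    node_max = node_h.max(dim=0).values
    edge_mean = edge_h.mean(dim=0)
    return torch.cat([node_mean, node_max, edge_mean])

node_h = torch.tensor([[1.0, 2.0], [3.0, 0.0], [2.0, 4.0]])
edge_h = torch.tensor([[1.0], [3.0]])
u = global_state(node_h, edge_h)  # tensor([2., 2., 3., 4., 2.])
```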



It’s a really interesting task! Definitely worth a tutorial or blog post (we could discuss how to go through the process later).

You’ve already made a good start by creating a data loader that produces a BatchedDGLGraph. You can treat the BatchedDGLGraph as a single graph with multiple connected components (each one is a diagram in the batch), so running GCN on the BatchedDGLGraph is the same as running GCN on each diagram in parallel. GCN will produce an embedding (or representation) for each node. To classify the entire graph, you need to convert these node embeddings into a graph-level (or global) representation. There are a couple of ways to do so:

  1. Use our provided readout functions for BatchedDGLGraph. For example, sum_nodes generates a graph embedding by summing up node embeddings. The first dimension of the returned tensor equals the number of graphs in the batch (see the example in the API doc).
  2. If you are using a more complicated readout function that is not covered by our provided routines, such as the Set2Set model or the function described in the DGMG model, you can unbatch the BatchedDGLGraph, which returns a list of DGLGraph objects. You can then get the node embeddings of each graph and process them however you want.
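The claim that batching is "one graph with several connected components" can be checked with plain PyTorch. The sketch below uses dense adjacency matrices and a simplified, unnormalized propagation step (neighbour sum) as assumptions. It mimics the ideas behind `sum_nodes` and unbatching but is not DGL's implementation:

```python
import torch

def propagate(adj, h):
    """One simplified, unnormalized GCN-style step: sum neighbour features."""
    return adj @ h

# Two small undirected graphs as dense adjacency matrices.
adj_a = torch.tensor([[0., 1., 1.], [1., 0., 0.], [1., 0., 0.]])  # 3-node star
adj_b = torch.tensor([[0., 1.], [1., 0.]])                          # single edge
h_a = torch.rand(3, 8)
h_b = torch.rand(2, 8)

# "Batching": a block-diagonal adjacency, i.e. one disconnected graph.
adj_batch = torch.block_diag(adj_a, adj_b)
h_batch = torch.cat([h_a, h_b])
out_batch = propagate(adj_batch, h_batch)

# Propagating over the batch equals propagating over each graph separately.
same = torch.allclose(
    out_batch, torch.cat([propagate(adj_a, h_a), propagate(adj_b, h_b)]))

# Sum readout: one row per graph, giving a (num_graphs, D) tensor.
graph_ids = torch.tensor([0, 0, 0, 1, 1])
readout = torch.zeros(2, 8).index_add_(0, graph_ids, out_batch)

# "Unbatching": split node embeddings back into per-graph chunks.
per_graph = torch.split(out_batch, [3, 2])
```

Because no entry of the block-diagonal adjacency links nodes of different graphs, no information leaks between diagrams during message passing.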

After you get the graph embeddings (ideally as a tensor of shape (B, D), where B is the batch size and D is the embedding size), you can map them to 17 logits, one per category, with an MLP, apply a cross-entropy loss against the labels, and let it train.
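A minimal sketch of that classification head follows. The hidden size, learning rate, and the random stand-in embeddings are assumptions for illustration. `nn.CrossEntropyLoss` takes raw logits and integer class labels, so no softmax layer is needed in the model:

```python
import torch
import torch.nn as nn

NUM_CLASSES = 17  # number of diagram categories mentioned in the question

class GraphClassifierHead(nn.Module):
    """MLP mapping (B, D) graph embeddings to (B, NUM_CLASSES) logits."""
    def __init__(self, in_dim, hidden_dim, num_classes=NUM_CLASSES):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, graph_emb):
        return self.mlp(graph_emb)

torch.manual_seed(0)
B, D = 8, 32
graph_emb = torch.randn(B, D)                  # stand-in for the readout output
labels = torch.randint(0, NUM_CLASSES, (B,))   # integer class ids
head = GraphClassifierHead(D, 64)
optimizer = torch.optim.Adam(head.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()                # applies log-softmax internally

losses = []
for step in range(50):
    logits = head(graph_emb)
    loss = loss_fn(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

In a real model, `graph_emb` would come from the GCN plus readout, so the gradients would also flow back into the GCN weights.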

This is how I would get started on this task. Of course, there will be some tuning and tweaking, much as in other image classification tasks.

Hope this helps, and I look forward to your progress.


This is a really interesting task. Another potential readout approach is to exploit the graph's clustering structure and compute graph embeddings hierarchically. Please read this paper for more details.
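One well-known technique of this kind is the soft cluster assignment used in DiffPool. Each node is softly assigned to k clusters by a matrix S, and the graph is coarsened as X' = SᵀX and A' = SᵀAS. The sketch below is a simplification: DiffPool computes S with a GNN, whereas here a single linear layer on the node features produces the assignment scores.

```python
import torch
import torch.nn as nn

class SoftClusterPool(nn.Module):
    """Coarsen a graph into k clusters using a learned soft assignment.

    Simplification (assumption): assignment scores come from a linear layer
    on node features rather than from a GNN.
    """
    def __init__(self, in_dim, num_clusters):
        super().__init__()
        self.assign = nn.Linear(in_dim, num_clusters)

    def forward(self, adj, h):
        s = torch.softmax(self.assign(h), dim=-1)  # [N, k], each row sums to 1
        h_coarse = s.t() @ h                       # [k, D] cluster features
        adj_coarse = s.t() @ adj @ s               # [k, k] cluster connectivity
        return adj_coarse, h_coarse

torch.manual_seed(0)
adj = torch.tensor([[0., 1., 0., 0.],
                    [1., 0., 1., 0.],
                    [0., 1., 0., 1.],
                    [0., 0., 1., 0.]])  # a 4-node path
h = torch.randn(4, 16)
pool = SoftClusterPool(16, num_clusters=2)
adj2, h2 = pool(adj, h)
graph_emb = h2.mean(dim=0)  # final flat readout on the coarsened graph
```

Stacking several such pooling steps, each followed by message passing on the coarsened graph, yields the hierarchical embedding described above.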


Hey everyone,

thanks for the suggestions and advice! I think I’ll start by testing the readout functions before trying any more complex approaches.

My next step would be to add more features to the nodes and edges (although accounting for edges would require a different network than vanilla GCN, right?), such as word vectors for text and image features extracted with a pre-trained CNN.
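Combining several feature sources per node usually comes down to concatenating them column-wise. The sketch below rests on several assumptions: bounding boxes are given as `(x1, y1, x2, y2)` in pixels and are scaled by the diagram size so that diagrams of different resolutions are comparable. The 300-d word vectors and 512-d CNN features are placeholder dimensions, and `node_features` is a hypothetical helper:

```python
import torch

def node_features(boxes, image_size, extra=None):
    """Build the node feature matrix of one diagram (hypothetical helper).

    boxes: [N, 4] bounding boxes (x1, y1, x2, y2) in pixels.
    image_size: (width, height), used to scale coordinates into [0, 1].
    extra: optional list of [N, F_i] tensors (e.g. word vectors for text
    nodes, CNN features for image regions) appended column-wise.
    """
    w, h = image_size
    scale = torch.tensor([w, h, w, h], dtype=boxes.dtype)
    parts = [boxes / scale]
    if extra:
        parts.extend(extra)
    return torch.cat(parts, dim=1)

boxes = torch.tensor([[0., 0., 50., 100.], [100., 50., 200., 200.]])
word_vecs = torch.zeros(2, 300)  # placeholder for 300-d word embeddings
cnn_feats = torch.zeros(2, 512)  # placeholder for pooled CNN features
x = node_features(boxes, (200, 200), [word_vecs, cnn_feats])
```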

In general, I think that for this kind of task (representing the structure and semantics of diagrams), one should try to preserve as much of this information as possible, so I’m open to all kinds of suggestions.

Thanks again & I’ll keep you posted on the updates!


Maybe there should also be a max_node function among the graph readout functions, since global max pooling is quite popular in CNNs for NLP tasks, and it may also be useful in GCNs when dealing with sentence graphs such as dependency trees.


It’s on our roadmap. However, we currently lack the necessary operator support in the backend (segment max, in this case).
For now, you can unbatch the graph and do the max readout on each graph manually.
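The manual workaround can be sketched in plain PyTorch. The sketch assumes that each graph's nodes are stored contiguously in the batched node-feature tensor and that the per-graph node counts are known. Splitting by those counts then plays the role of unbatching:

```python
import torch

def max_readout(node_h, num_nodes_per_graph):
    """Max-pool node embeddings graph by graph ("unbatch, then max").

    node_h: [total_nodes, D] node embeddings of the batch, each graph's
    nodes stored contiguously. num_nodes_per_graph: list of ints.
    Returns a [num_graphs, D] tensor.
    """
    chunks = torch.split(node_h, num_nodes_per_graph, dim=0)
    return torch.stack([chunk.max(dim=0).values for chunk in chunks])

node_h = torch.tensor([[1., 5.], [4., 2.], [0., 3.], [7., -1.], [6., 8.]])
out = max_readout(node_h, [3, 2])  # tensor([[4., 5.], [7., 8.]])
```

The Python loop over graphs is slower than a fused segment-max kernel, which is the backend operator mentioned above. Newer PyTorch versions also offer `Tensor.scatter_reduce` with `reduce="amax"`, which can compute a segment max without splitting.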


I have just started learning about graph convolutional neural networks, and I am a little confused about how to use BatchedDGLGraph. Is there any demo code for classification tasks with batched samples and BatchedDGLGraph? Thanks in advance.


Perhaps we should provide a model example for graph classification?


Thanks for your reply; a demo would really help.


It would be better to demonstrate the usage with a toy example or application. Any suggestions?


I can try to give a toy example, for instance binary classification between cycle_graph and star_graph graphs.


Hi AlvinLXS, below you can find a toy demo I wrote:

"""This is a demo for graph classification with dgl where we have a
synthetic dataset consisting of cycle_graphs and star_graphs and
we want to perform a binary classification."""
import dgl
import dgl.function as fn
import networkx as nx
import random
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Generate a synthetic dataset
ds = []
for i in range(5):
    num_nodes = 5
    cycle = dgl.DGLGraph(nx.cycle_graph(num_nodes))
    cycle.ndata['h'] = th.ones(num_nodes, 1)
    ds.append((cycle, 0))

    star = dgl.DGLGraph(nx.star_graph(num_nodes - 1))
    star.ndata['h'] = th.ones(num_nodes, 1)
    ds.append((star, 1))

g_list = [data[0] for data in ds]
labels = th.tensor([data[1] for data in ds]).float().view(-1, 1)

# Model and optimizer
weight1 = th.randn((1, 16), requires_grad=True)
weight2 = th.randn((16, 1), requires_grad=True)
optimizer = optim.Adam([weight1, weight2], lr=0.01)
loss_func = nn.BCELoss()

# Configure message passing. With the message func and reduce func defined
# below, the updated node feature will simply be node degree.
msg_func = fn.copy_src(src='h', out='m')
reduce_func = fn.sum(msg='m', out='h')

# Training
for i in range(100):
    bg = dgl.batch(g_list)
    # Perform message passing
    bg.update_all(msg_func, reduce_func)
    # Readout and get graph features for the 10 graphs.
    bg_h = dgl.mean_nodes(bg, 'h')
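    logits = F.relu(th.mm(bg_h, weight1))
    logits = th.mm(logits, weight2)
    prediction = th.sigmoid(logits)

    loss = loss_func(prediction, labels)
    # Backpropagate and update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('The prediction loss is {}'.format(loss.item()))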

This demo simply shows how to batch a list of graphs, how to perform message passing directly over a batch of graphs, and how to perform readout to get graph features for the batch. In fact, you will see that in this toy example our classifier fails to classify well in the end due to a bad feature design. Hopefully this gives you a skeleton and a high-level idea for implementing the graph classification models you like. If you find a particularly interesting model for graph classification, please let us know and we may provide example code for it. You are also very welcome to contribute your model code.


Thanks for your demo; I will learn how to use it as soon as possible. Thanks again.