Classify graph types like TensorFlow


I am currently learning about GCNs and how to use DGL.
Previously, I used TensorFlow to classify images. With DGL, I want to do the same thing, but classify different types of graphs. I followed one of the tutorials in the DGL documentation (…). However, I want to give a graph as input to the GCN and have it return the class name / type of that graph, the way TensorFlow does for images.

Is it possible to do this? If so, do you have an example?


We have a new graph classification tutorial in section "5.4 Graph Classification" of the DGL 0.6.1 documentation (but it's in PyTorch).

Graph classification works much like image classification: you compute a vector representation for each graph (typically by running GNN layers to get node embeddings and then pooling them, e.g. by averaging, into a single vector) and then use an MLP layer to classify it. Unfortunately, we don't have a TensorFlow graph classification example, but the general process is the same as in the node classification example:
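To make the "one vector per graph, then a classifier" idea concrete without depending on DGL or a particular backend, here is a minimal sketch in plain NumPy. It assumes one GCN layer, mean pooling as the readout, and a single linear layer standing in for the MLP. The class names, sizes, and random weights are hypothetical; in practice the weights are learned, and in DGL you would use its GCN layers and readout functions instead.

```python
import numpy as np

def normalized_adjacency(adj):
    """Symmetric GCN normalization D^-1/2 (A + I) D^-1/2, with self-loops added."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def graph_logits(adj, feat, w_gcn, w_out, b_out):
    """One GCN layer -> mean readout -> linear classifier (stand-in for the MLP)."""
    h = np.maximum(normalized_adjacency(adj) @ feat @ w_gcn, 0.0)  # node embeddings (ReLU)
    graph_vec = h.mean(axis=0)        # readout: a single vector for the whole graph
    return graph_vec @ w_out + b_out  # one logit per class

# Hypothetical toy setup: 3 classes, 4-dim node features, 8 hidden units
rng = np.random.default_rng(0)
class_names = ["cycle", "star", "path"]  # hypothetical class names
w_gcn = rng.normal(size=(4, 8))
w_out = rng.normal(size=(8, len(class_names)))
b_out = np.zeros(len(class_names))

# A 4-node cycle graph with random node features
adj = np.array([[0, 1, 0, 1],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [1, 0, 1, 0]], dtype=float)
feat = rng.normal(size=(4, 4))

logits = graph_logits(adj, feat, w_gcn, w_out, b_out)
# The weights are untrained, so the printed name is arbitrary here
print(class_names[int(np.argmax(logits))])
```

Because the readout averages over nodes, the graph vector does not depend on node ordering or on the number of nodes, which is what lets one classifier handle graphs of different sizes.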

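for graph, label in graph_list:
  representation = model(graph, graph.ndata['feat'])  # graph and node features, the same as in the node classification example
  logits = MLP(representation)  # classifier head, the same as in image classification
  loss = loss_fcn(logits, label)  # e.g. cross-entropy; used for training, not for prediction
  predicted_class = logits.argmax(-1)  # index of the highest-scoring class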
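In a loop like this, the loss function's output is only a training signal; the predicted class comes from the logits, and mapping its index to a list of names gives the "class name" the question asks for. A small NumPy sketch of both steps, assuming softmax cross-entropy as the loss and hypothetical class names and logit values:

```python
import numpy as np

def cross_entropy(logits, label):
    """Softmax cross-entropy for one graph: -log p(label)."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]

def predict(logits, class_names):
    """Return the name of the class with the highest logit."""
    return class_names[int(np.argmax(logits))]

class_names = ["cycle", "star", "path"]  # hypothetical class names
logits = np.array([0.2, 2.5, -1.0])      # hypothetical classifier output for one graph
label = 1                                # index of the true class ("star")

loss = cross_entropy(logits, label)  # minimized during training
print(predict(logits, class_names))  # star
```

In PyTorch, `torch.nn.functional.cross_entropy` (or `nn.CrossEntropyLoss`) takes the raw logits and performs the softmax internally, so the model itself should output logits rather than probabilities.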

DGL also offers ways to speed up computation, such as batching multiple graphs together with `dgl.batch`, as shown in the tutorial mentioned above. However, batching isn't necessary, and I would suggest starting with a simple model.
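Batching works because a batch of graphs can be treated as one large graph whose components are disconnected, so message passing never mixes nodes from different graphs; in DGL this is what `dgl.batch` builds, and readout functions such as `dgl.mean_nodes` then pool per graph. The sketch below uses plain NumPy matrices instead of DGL, with one GCN layer and hypothetical random weights, and checks that one pass over the merged graph gives the same graph vectors as processing each graph separately.

```python
import numpy as np

def gcn_layer(adj, feat, w):
    """One GCN layer with self-loops and symmetric normalization, then ReLU."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    norm = a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(norm @ feat @ w, 0.0)

def batch_graphs(adjs, feats):
    """Merge graphs into one block-diagonal graph (disjoint components)."""
    n = sum(a.shape[0] for a in adjs)
    big_adj = np.zeros((n, n))
    offset = 0
    for a in adjs:
        k = a.shape[0]
        big_adj[offset:offset + k, offset:offset + k] = a
        offset += k
    return big_adj, np.concatenate(feats, axis=0), [a.shape[0] for a in adjs]

def mean_readout(h, num_nodes):
    """Average node embeddings separately for each graph in the batch."""
    bounds = np.cumsum([0] + num_nodes)
    return np.stack([h[s:e].mean(axis=0) for s, e in zip(bounds[:-1], bounds[1:])])

rng = np.random.default_rng(0)
w = rng.normal(size=(3, 5))  # hypothetical GCN weights: 3 input features -> 5 hidden
triangle = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
path = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
feats = [rng.normal(size=(3, 3)), rng.normal(size=(4, 3))]

# One pass over the batched (merged) graph ...
big_adj, big_feat, num_nodes = batch_graphs([triangle, path], feats)
batched = mean_readout(gcn_layer(big_adj, big_feat, w), num_nodes)

# ... gives the same graph vectors as processing each graph on its own
separate = np.stack([gcn_layer(a, f, w).mean(axis=0) for a, f in zip([triangle, path], feats)])
print(np.allclose(batched, separate))  # True
```

The gain from batching is that many small graphs become one larger matrix operation, which suits GPUs much better than a Python loop over graphs; the per-graph loop in the earlier example computes the same thing, just more slowly.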