Hi AlvinLXS, below you can find a toy demo I wrote:
```python
"""This is a demo for graph classification with DGL where we have a
synthetic dataset consisting of cycle graphs and star graphs and
we want to perform a binary classification."""
import dgl
import dgl.function as fn
import networkx as nx
import random
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Generate a synthetic dataset: 5 cycle graphs (label 0) and 5 star graphs
# (label 1), each with 5 nodes and an all-ones node feature 'h'.
# (In DGL >= 0.5, dgl.from_networkx(...) is the recommended way to build these.)
ds = []
for i in range(5):
    num_nodes = 5
    cycle = dgl.DGLGraph(nx.cycle_graph(num_nodes))
    cycle.ndata['h'] = th.ones(num_nodes, 1)
    ds.append((cycle, 0))
    # nx.star_graph(n) has n + 1 nodes: one hub plus n leaves
    star = dgl.DGLGraph(nx.star_graph(num_nodes - 1))
    star.ndata['h'] = th.ones(num_nodes, 1)
    ds.append((star, 1))
random.shuffle(ds)
g_list = [data[0] for data in ds]
labels = th.tensor([data[1] for data in ds]).float().view(-1, 1)

# Model and optimizer
weight1 = th.randn((1, 16), requires_grad=True)
weight2 = th.randn((16, 1), requires_grad=True)
optimizer = optim.Adam([weight1, weight2], lr=0.01)
loss_func = nn.BCELoss()

# Configure message passing. With the message func and reduce func defined
# below, the updated node feature will simply be node degree.
# (Newer DGL releases use fn.copy_u(u='h', out='m') instead of fn.copy_src.)
msg_func = fn.copy_src(src='h', out='m')
reduce_func = fn.sum(msg='m', out='h')

# Training
for i in range(100):
    # Merge the 10 graphs into one batched graph
    bg = dgl.batch(g_list)
    # Perform message passing over the whole batch at once
    bg.update_all(msg_func, reduce_func)
    # Readout: average node features per graph, giving one row per graph
    bg_h = dgl.mean_nodes(bg, 'h')
    logits = F.relu(th.mm(bg_h, weight1))
    logits = th.mm(logits, weight2)
    prediction = th.sigmoid(logits)
    optimizer.zero_grad()
    loss = loss_func(prediction, labels)
    print('The prediction loss is {}'.format(loss.item()))
    loss.backward()
    optimizer.step()
```
This demo simply shows how to batch a list of graphs, how to perform message passing directly over a batch of graphs, and how to perform readout to get graph features for the batch. In fact, you will notice that in this toy example the classifier fails to classify the graphs well in the end, due to a bad feature design: after one round of message passing on all-ones inputs, each graph is summarized only by its mean node degree. Hopefully this gives you a skeleton and a high-level idea for implementing the graph classification models you like. If you come across a particularly interesting model for graph classification, please let us know and we may provide example code for that model. You are also very welcome to contribute your model code.
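To make the batch, message-passing and readout steps concrete, here is a plain-Python sketch (no DGL or PyTorch) that mimics what `dgl.batch`, `update_all` with copy-and-sum, and `dgl.mean_nodes` compute for one 5-node cycle and one 5-node star. The helpers `cycle_edges`, `star_edges`, `batch`, `copy_and_sum` and `mean_readout` are hypothetical and only illustrate the idea; this is not how DGL implements them internally. It assumes each undirected edge is stored as two directed edges, which is how DGL treats undirected networkx graphs.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def cycle_edges(n: int) -> List[Edge]:
    # Hypothetical helper: cycle 0-1-...-(n-1)-0, each edge in both directions
    edges: List[Edge] = []
    for i in range(n):
        j = (i + 1) % n
        edges += [(i, j), (j, i)]
    return edges


def star_edges(n: int) -> List[Edge]:
    # Hypothetical helper: node 0 is the hub linked to leaves 1..n-1, both directions
    edges: List[Edge] = []
    for leaf in range(1, n):
        edges += [(0, leaf), (leaf, 0)]
    return edges


def batch(graphs: List[Tuple[int, List[Edge]]]) -> Tuple[int, List[Edge], List[int]]:
    # Disjoint union: shift each graph's node ids by the number of nodes before it,
    # and remember each graph's size so the readout can split the batch again
    edges: List[Edge] = []
    sizes: List[int] = []
    offset = 0
    for num_nodes, g_edges in graphs:
        edges += [(u + offset, v + offset) for u, v in g_edges]
        sizes.append(num_nodes)
        offset += num_nodes
    return offset, edges, sizes


def copy_and_sum(num_nodes: int, edges: List[Edge], h: List[float]) -> List[float]:
    # Each edge copies its source feature; each node sums its incoming messages
    out = [0.0] * num_nodes
    for u, v in edges:
        out[v] += h[u]
    return out


def mean_readout(h: List[float], sizes: List[int]) -> List[float]:
    # Average the node features of each graph within the batch
    means: List[float] = []
    start = 0
    for n in sizes:
        means.append(sum(h[start:start + n]) / n)
        start += n
    return means


graphs = [(5, cycle_edges(5)), (5, star_edges(5))]
total, edges, sizes = batch(graphs)
h = copy_and_sum(total, edges, [1.0] * total)
print(mean_readout(h, sizes))  # [2.0, 1.6]
```

The readout yields exactly one scalar per graph: 2.0 for every 5-node cycle and 1.6 for every 5-node star. Since both values are positive and the demo's model has no bias terms, its logit is just a fixed multiple of that scalar, so both classes always land on the same side of 0.5, which is why the demo cannot reach a good classification.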