Try this version:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from dgl.nn.pytorch import NNConv, Set2Set
class MPNNGNN(nn.Module):
    """MPNN message-passing module.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    This class performs message passing over a batch of graphs and returns
    the updated node representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    """
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
                 edge_hidden_feats=128, num_step_message_passing=6):
        super(MPNNGNN, self).__init__()
        # Project input node features to the hidden size used during
        # message passing.
        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, node_out_feats),
            nn.ReLU()
        )
        self.num_step_message_passing = num_step_message_passing
        # Maps each edge's features to a (node_out_feats, node_out_feats)
        # matrix that NNConv uses to transform incoming neighbor messages.
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(),
            nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
        )
        self.gnn_layer = NNConv(
            in_feats=node_out_feats,
            out_feats=node_out_feats,
            edge_func=edge_network,
            aggregator_type='sum'
        )
        # A single GRU cell is reused across all message passing steps.
        self.gru = nn.GRU(node_out_feats, node_out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.project_node_feats[0].reset_parameters()
        self.gnn_layer.reset_parameters()
        for layer in self.gnn_layer.edge_nn:
            if isinstance(layer, nn.Linear):
                layer.reset_parameters()
        self.gru.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_out_feats)
            Output node representations.
        """
        node_feats = self.project_node_feats(node_feats)  # (V, node_out_feats)
        hidden_feats = node_feats.unsqueeze(0)            # (1, V, node_out_feats)

        for _ in range(self.num_step_message_passing):
            node_feats = F.relu(self.gnn_layer(g, node_feats, edge_feats))
            # nn.GRU expects (seq_len, batch, input); nodes act as the batch
            # dimension with a sequence length of 1.
            node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0),
                                                hidden_feats)
            node_feats = node_feats.squeeze(0)
        return node_feats
class MPNNPredictor(nn.Module):
    """MPNN for regression and classification on graphs.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    num_step_set2set : int
        Number of set2set steps. Default to 6.
    num_layer_set2set : int
        Number of set2set layers. Default to 3.
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=64,
                 edge_hidden_feats=128,
                 n_tasks=1,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3):
        super(MPNNPredictor, self).__init__()
        # Message passing stage: produces one representation per node.
        self.gnn = MPNNGNN(node_in_feats=node_in_feats,
                           node_out_feats=node_out_feats,
                           edge_in_feats=edge_in_feats,
                           edge_hidden_feats=edge_hidden_feats,
                           num_step_message_passing=num_step_message_passing)
        # Readout stage: pools node representations into one vector per
        # graph. Set2Set outputs vectors of size 2 * input_dim, which is
        # why the first prediction layer takes 2 * node_out_feats below.
        self.readout = Set2Set(input_dim=node_out_feats,
                               n_iters=num_step_set2set,
                               n_layers=num_layer_set2set)
        # Prediction head mapping graph representations to task outputs.
        self.predict = nn.Sequential(
            nn.Linear(2 * node_out_feats, node_out_feats),
            nn.ReLU(),
            nn.Linear(node_out_feats, n_tasks)
        )

    def forward(self, g, node_feats, edge_feats):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features.

        Returns
        -------
        float32 tensor of shape (G, n_tasks)
            Prediction for the graphs in the batch. G for the number of graphs.
        """
        per_node = self.gnn(g, node_feats, edge_feats)
        per_graph = self.readout(g, per_node)
        return self.predict(per_graph)
# Training loop. Assumes `trainset` (exposing num_classes) and `data_loader`
# (yielding (batched_graph, label) pairs) are defined earlier in the file.
model = MPNNPredictor(node_in_feats=2, edge_in_feats=1, n_tasks=trainset.num_classes,
                      num_step_message_passing=2, num_step_set2set=2, num_layer_set2set=2)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
epoch_losses = []
for epoch in range(80):
    epoch_loss = 0.0
    num_batches = 0
    for bg, label in data_loader:
        # MPNNPredictor.forward requires node and edge features in addition
        # to the graph; calling model(bg) alone raises a TypeError.
        # NOTE(review): assumes features are stored under the 'feat' key of
        # ndata/edata -- confirm against how the dataset was built.
        prediction = model(bg, bg.ndata['feat'], bg.edata['feat'])
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        num_batches += 1
    # Guard against an empty loader; `max` avoids ZeroDivisionError.
    epoch_loss /= max(num_batches, 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
Also, I wonder whether I can train on mini-batches of graphs here — I don't quite understand how graph batching works in this setup.