MPNN implementation

The Edge-conditioned Convolution in DGL

References:
Paper: https://arxiv.org/abs/1704.01212
Author’s code: https://github.com/brain-research/mpnn
PyG implementation: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/nn_conv.html#NNConv
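
For orientation, the update that the code below computes can be written roughly as follows. The symbols are notation chosen here for illustration (they are not taken from the paper or from the post): row-vector node features h, an edge network f_edge whose output is reshaped into a d_in × d_out matrix per edge, the root weight W_root, and the bias b.

```latex
\mathbf{h}_i' \;=\; \mathbf{h}_i W_{\mathrm{root}}
  \;+\; \sum_{j \in \mathcal{N}(i)} \mathbf{h}_j \, W_{ji} \;+\; \mathbf{b},
\qquad
W_{ji} \;=\; \operatorname{reshape}\!\big(f_{\mathrm{edge}}(\mathbf{e}_{ji})\big)
  \in \mathbb{R}^{d_{\mathrm{in}} \times d_{\mathrm{out}}}
```

With mean aggregation, the sum over incoming neighbors is divided by the number of incoming edges of node i.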

import torch
import torch.nn as nn
from torch.nn import Parameter
import dgl.function as fn

class NNConvLayer(nn.Module):
    def __init__(self,
                 g,
                 in_channels,
                 out_channels,
                 edge_net,
                 aggr='add',
                 root_weight=True,
                 bias=True):
        super(NNConvLayer, self).__init__()
        
        self.g = g
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_net = edge_net
        self.aggr = aggr

        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()
        
    
    def reset_parameters(self):
        if self.root is not None:
            nn.init.xavier_normal_(self.root.data, gain=1.414)
        if self.bias is not None:
            self.bias.data.zero_()
        for m in self.edge_net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data, gain=1.414)
                
    def message(self, edges):
        return {'m' : torch.matmul(edges.src['h'].unsqueeze(1),edges.data['w']).squeeze(1)}
    
    def reduce(self, nodes):
        if self.aggr == 'add':
            return {'aggr_out' : torch.sum(nodes.mailbox['m'], 1)}
        elif self.aggr == 'mean':
            return {'aggr_out' : torch.mean(nodes.mailbox['m'], 1)}
        else:
            raise ValueError(f"Unsupported aggr: {self.aggr!r}")
            
    def apply_node_func(self, nodes):
        aggr_out = nodes.data['aggr_out']        
        if self.root is not None:
            aggr_out = torch.mm(nodes.data['h'], self.root) + aggr_out
        
        if self.bias is not None:
            aggr_out = aggr_out + self.bias   
        
        return {'h': aggr_out}

    def forward(self, h, e):
        h = h.unsqueeze(-1) if h.dim() == 1 else h
        e = e.unsqueeze(-1) if e.dim() == 1 else e
        
        self.g.ndata['h'] = h
        self.g.edata['w'] = self.edge_net(e).view(-1, self.in_channels, self.out_channels)
        
        # reduce() already branches on self.aggr, so a single call covers both modes.
        self.g.update_all(self.message, self.reduce, self.apply_node_func)

        return self.g.ndata.pop('h')

It's not great work yet, so I'm looking forward to any advice from you guys! :smiley:

Awesome job! Would you like to contribute it to our repo? You could raise a PR and we will help you through the steps. Take a look at our GraphConv module to get an idea of what a module looks like. For example, we’d like to put g in the forward arguments so people can use the layer on dynamic graphs; we’d also like to see builtin reduce functions, since they are faster.

I am willing to contribute it! The problem is that my implementation is not professional enough and potential bugs may exist. I will improve my code as much as possible, and it would be great if someone could make constructive suggestions. :smiley:

@MilkshakeForReal don’t worry about it. We are all the same :slight_smile:. When you feel ready, just shoot a PR and we will review it.

There are two problems with the builtin functions. First, fn.u_mul_e('h', 'w', 'm') performs element-wise multiplication, while what I need is matrix multiplication. Second, the builtin reduce function fn.mean() does not exist.
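
On the first point, the per-edge matrix product can also be written as a broadcast element-wise product followed by a sum over the input dimension. The sketch below checks that identity in plain PyTorch, with no DGL involved. The shapes and variable names are made up for illustration. Whether this maps cleanly onto a builtin element-wise message function is an assumption, not something tested here.

```python
import torch

torch.manual_seed(0)
num_edges, in_channels, out_channels = 5, 3, 4

# Source-node features (one row per edge) and per-edge weight matrices.
h_src = torch.randn(num_edges, in_channels)
w = torch.randn(num_edges, in_channels, out_channels)

# Matrix-multiplication message, as in the custom message function.
m_matmul = torch.matmul(h_src.unsqueeze(1), w).squeeze(1)

# Same message from an element-wise product: broadcast h over the output
# dimension, multiply, then sum over the input dimension.
m_elementwise = (h_src.unsqueeze(-1) * w).sum(dim=1)

print(torch.allclose(m_matmul, m_elementwise, atol=1e-6))
```

Because summing over neighbors and summing over the input dimension commute, the final sum could in principle be applied after aggregation. The cost is that each message then has shape (in_channels, out_channels) instead of (out_channels,), which uses more memory.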

Code update

import torch
import torch.nn as nn
from torch.nn import Parameter
import dgl.function as fn

class NNConvLayer(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 edge_net,
                 aggr='sum',
                 root_weight=True,
                 bias=True):
        super(NNConvLayer, self).__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_net = edge_net
        self.aggr = aggr

        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()
        
    
    def reset_parameters(self):
        if self.root is not None:
            nn.init.xavier_normal_(self.root.data, gain=1.414)
        if self.bias is not None:
            self.bias.data.zero_()
        for m in self.edge_net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data, gain=1.414)
                
    def message(self, edges):
        return {'m' : torch.matmul(edges.src['h'].unsqueeze(1),edges.data['w']).squeeze(1)}
    
    def reduce(self, nodes):
        # Look up torch.sum / torch.mean by name instead of using eval().
        return {'aggr_out' : getattr(torch, self.aggr)(nodes.mailbox['m'], 1)}
        
            
    def apply_node_func(self, nodes):
        aggr_out = nodes.data['aggr_out']        
        if self.root is not None:
            aggr_out = torch.mm(nodes.data['h'], self.root) + aggr_out
        
        if self.bias is not None:
            aggr_out = aggr_out + self.bias   
        
        return {'h': aggr_out}

    def forward(self, g, h, e):
        h = h.unsqueeze(-1) if h.dim() == 1 else h
        e = e.unsqueeze(-1) if e.dim() == 1 else e
        
        g.ndata['h'] = h
        g.edata['w'] = self.edge_net(e).view(-1, self.in_channels, self.out_channels)
        g.update_all(self.message, self.reduce,self.apply_node_func)
        return g.ndata.pop('h')
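
For anyone who wants to test the layer, here is a minimal pure-PyTorch reference of the same forward computation that could be compared against the DGL output. The helper name nnconv_reference and its arguments are hypothetical and not part of DGL or of the code above. It takes the per-edge weights w already reshaped to (E, in_channels, out_channels). Its handling of nodes with no incoming edges (root term plus bias only) is an assumption and may differ from what update_all does.

```python
import torch

def nnconv_reference(src, dst, h, w, num_nodes, root=None, bias=None, aggr="sum"):
    """Reference forward pass (hypothetical helper, not a DGL API).

    src, dst: LongTensors of shape (E,), edges point src -> dst.
    h: (N, in_channels) node features.
    w: (E, in_channels, out_channels) per-edge weight matrices.
    """
    # One message per edge: h[src[e]] @ w[e]
    m = torch.bmm(h[src].unsqueeze(1), w).squeeze(1)

    # Sum the messages arriving at each destination node.
    out = torch.zeros(num_nodes, w.shape[-1], dtype=h.dtype, device=h.device)
    out.index_add_(0, dst, m)

    if aggr == "mean":
        # Divide by in-degree; clamp avoids division by zero for isolated nodes.
        deg = torch.bincount(dst, minlength=num_nodes).clamp(min=1)
        out = out / deg.unsqueeze(-1).to(h.dtype)
    elif aggr != "sum":
        raise ValueError(f"unsupported aggr: {aggr!r}")

    if root is not None:
        out = out + h @ root
    if bias is not None:
        out = out + bias
    return out

# Tiny example: 3 nodes, 4 edges, all-ones features and weights.
src = torch.tensor([0, 1, 2, 2])
dst = torch.tensor([1, 2, 0, 1])
h = torch.ones(3, 2)
w = torch.ones(4, 2, 3)
out_sum = nnconv_reference(src, dst, h, w, num_nodes=3)
out_mean = nnconv_reference(src, dst, h, w, num_nodes=3, aggr="mean")
```

In this example every message equals [2, 2, 2], and node 1 receives two of them, so the sum and mean outputs differ only in that row.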

It seems that to raise a PR I have to write the whole example, including dataset.py, train.py and so on, but I don’t really have enough time to complete it. I am going to leave this API here for the moment. :upside_down_face:

Hi @MilkshakeForReal, no worries. If you don’t mind, I can try finishing the rest of the work and you can take a look then.

I don’t mind at all :smiley:. It couldn’t be better if you could help me out!

Hi, we have finished the implementation of MPNN on GitHub based on your draft. Maybe you could take some time to review and test the implementation?
Thank you~ :grinning:
