Is there a built-in function to do edge updates?


I am new to DGL and I want to build an MPNN with edge updates, as in this paper.
In short, I want to update the edge between nodes i and j by
$$ e_{ij} = E_t(h_i, h_j, e_{ij}), $$ where $E_t$ is the edge update function. I searched the source code and found that one possible option is `group_apply_edges()`. Could anyone please give me some suggestions on this?


Let me show you an example.

If $E_{ij} = h_i + h_j + e_{ij}$, then we can define a function to compute the edge data with

import torch
from dgl import DGLGraph

def f_edge(edges):
    """Compute E_ij = h_i + h_j + e_ij for every edge in a batch.

    Meant to be passed to ``DGLGraph.apply_edges``. The edge batch exposes
    the source-node features (``edges.src``), the destination-node features
    (``edges.dst``) and the edges' own features (``edges.data``).

    Returns:
        A dict mapping the new edge-feature name ``'E'`` to the per-edge sum.
    """
    return {'E': edges.src['h'] + edges.dst['h'] + edges.data['e']}

# Toy graph with three nodes and the edges 0 -> 1 and 1 -> 2.
g = DGLGraph()
# Create the nodes explicitly first, so this also works on DGL versions
# whose add_edges does not create missing nodes automatically.
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])
g.ndata['h'] = torch.tensor([[0.], [1.], [2.]])
g.edata['e'] = torch.tensor([[3.], [4.]])

# Run f_edge over every edge. The returned dict is stored as the edge
# feature 'E': here [[0 + 1 + 3], [1 + 2 + 4]] = [[4.], [7.]].
g.apply_edges(f_edge)


You can find the API reference of apply_edges here.

Also let me know if you are interested in some particular E_t function.

If you want to use builtin functions,

import torch
import dgl.function as fn

# Built-in-function version of the edge update E_ij = h_i + h_j + e_ij.
# The `...` values are placeholders for a real graph and feature tensors.
g = ... # some DGLGraph
g.ndata['h'] = ...  # node features h
g.edata['e'] = ...  # edge features e
# For every edge u -> v, store h_u + h_v as the new edge feature 'hh'.
g.apply_edges(fn.u_add_v('h', 'h', 'hh'))
# Add the existing edge features to obtain E_ij = h_i + h_j + e_ij.
g.edata['E'] = g.edata['e'] + g.edata['hh']

Thanks, Mufei.

If you have time, I would appreciate it if you could take a look at my implementation of Message Passing with Edge Updates. I know it is a bit impolite to ask you to debug my code, so if you do not have time, please feel free to ignore this.

The edge update function I want to use is something like the one below:

However, when I added this to the MPNN model by @geekinglcq, it did not work.

#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from import dataloader
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dgl_nn

import os
from os import listdir
from os.path import isfile, join
import networkx as nx
import numpy as np
import pickle

class NNConvLayer(nn.Module):
    """MPEU convolution layer: edge-conditioned message passing with edge updates.

    Each call first updates every edge from its two endpoint node features
    and its current edge feature (``update_edge``). The updated edge
    representation is reshaped into a per-edge weight matrix, which
    transforms the source-node feature into a message (``message``).
    Messages are summed at each destination node and optionally combined
    with a root (self) transform and a bias (``apply_node_func``).

    Args:
        in_channels: size of the incoming node features.
        out_channels: size of the produced node features.
        edge_net: module mapping ``cat(h_src, h_dst, e)`` to a vector of
            size ``in_channels * out_channels``.
        edge_lin: module mapping the ``edge_net`` output to the new edge
            feature stored back as ``'e_feat'``.
        root_weight: if True, add ``h @ root`` to the aggregated messages.
        bias: if True, add a learnable bias to the output.
    """

    def __init__(self, in_channels, out_channels, edge_net, edge_lin,
                 root_weight=True, bias=True):
        super(NNConvLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_net = edge_net
        self.edge_lin = edge_lin

        # Register an explicit None when a component is disabled, so that
        # self.root / self.bias always exist and can be tested against None.
        if root_weight:
            self.root = Parameter(torch.Tensor(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        # torch.Tensor(...) allocates uninitialised memory; initialise it now.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the root weight, the bias and edge_net's Linear layers."""
        if self.root is not None:
            nn.init.xavier_normal_(self.root, gain=1.414)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        for m in self.edge_net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight, gain=1.414)

    def message(self, edges):
        """Transform each source feature by its edge's weight matrix ``'w'``."""
        # (E, 1, in) @ (E, in, out) -> (E, 1, out) -> (E, out)
        return {'m': torch.matmul(edges.src['h'].unsqueeze(1),
                                  edges.data['w']).squeeze(1)}

    def apply_node_func(self, nodes):
        """Combine the summed messages with the optional root term and bias."""
        aggr_out = nodes.data['aggr_out']
        if self.root is not None:
            aggr_out = torch.mm(nodes.data['h'], self.root) + aggr_out

        if self.bias is not None:
            aggr_out = aggr_out + self.bias

        return {'h': aggr_out}

    def update_edge(self, edges):
        """Update the edges from ``cat(h_src, h_dst, e_feat)``.

        Returns:
            ``'w'``: per-edge weight matrices of shape
            ``(E, in_channels, out_channels)``, used by ``message``.
            ``'e_feat'``: the new edge features produced by ``edge_lin``.
        """
        cedge = torch.cat(
            (edges.src['h'], edges.dst['h'], edges.data['e_feat']), -1)
        cedge = self.edge_net(cedge)
        wedge = cedge.view(-1, self.in_channels, self.out_channels)
        nedge = self.edge_lin(cedge)
        return {"w": wedge, "e_feat": nedge}

    def forward(self, g, h, e):
        """Run one edge update followed by one round of message passing.

        Args:
            g: the DGL graph. Its ``'e_feat'`` edge data is overwritten with
                the updated edge features as a side effect.
            h: node features, ``(N,)`` or ``(N, in_channels)``.
            e: edge features, ``(E,)`` or ``(E, edge_dim)``.

        Returns:
            The new node features, ``(N, out_channels)``.
        """
        h = h.unsqueeze(-1) if h.dim() == 1 else h
        e = e.unsqueeze(-1) if e.dim() == 1 else e

        g.ndata['h'] = h
        g.edata['e_feat'] = e
        # The edge update must run first: it produces the weights 'w'
        # that `message` reads.
        g.apply_edges(self.update_edge)
        g.update_all(self.message, fn.sum("m", "aggr_out"),
                     self.apply_node_func)
        return g.ndata.pop('h')

class MPEUModel(nn.Module):
    """Message Passing with Edge Updates (MPEU) model.

    Node features are embedded by ``lin0`` and refined for
    ``num_step_message_passing`` steps. Each step runs one ``NNConvLayer``,
    which also rewrites ``g.edata['e_feat']`` with updated edge features,
    followed by a GRU that carries a hidden state across steps. A Set2Set
    readout and two linear layers then produce the graph-level prediction.

    Args:
        node_input_dim: size of the input node features ``g.ndata['n_feat']``.
        edge_input_dim: size of the edge features ``g.edata['e_feat']``.
        output_dim: size of the prediction.
        node_hidden_dim: hidden size of the node features.
        edge_hidden_dim: hidden size of the edge network.
        num_step_message_passing: number of message-passing steps.
        num_step_set2set: number of Set2Set iterations.
        num_layer_set2set: number of Set2Set layers.
        device: device the input node features are moved to.
    """

    def __init__(self,
                 node_input_dim=15,
                 edge_input_dim=5,
                 output_dim=12,
                 node_hidden_dim=64,
                 edge_hidden_dim=128,
                 num_step_message_passing=6,
                 num_step_set2set=6,
                 num_layer_set2set=3,
                 device="cpu"):
        super(MPEUModel, self).__init__()
        self.name = "MPEU"
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        self.device = device

        # Input is cat(h_src, h_dst, e); output is one
        # node_hidden_dim x node_hidden_dim weight matrix per edge. Without
        # the ReLU, the two Linear layers would collapse into one linear map.
        edge_network = nn.Sequential(
            nn.Linear(int(edge_input_dim + 2 * node_hidden_dim), edge_hidden_dim),
            nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
        # Projects the edge-network output back to the edge-feature size, so
        # the updated edges can feed the next message-passing step.
        edge_lin = nn.Linear(node_hidden_dim * node_hidden_dim, edge_input_dim)

        self.conv = NNConvLayer(in_channels=node_hidden_dim,
                                out_channels=node_hidden_dim,
                                edge_net=edge_network,
                                edge_lin=edge_lin,
                                root_weight=False)
        self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)

        self.set2set = dgl_nn.glob.Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
        self.lin1 = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
        self.lin2 = nn.Linear(node_hidden_dim, output_dim)

    def forward(self, g):
        """Predict graph-level targets for the (batched) graph ``g``.

        Reads ``g.ndata['n_feat']`` and ``g.edata['e_feat']``. The latter is
        overwritten by the edge updates performed inside ``self.conv``.
        """
        n_feat = g.ndata['n_feat'].to(self.device)
        out = F.relu(self.lin0(n_feat))
        # GRU hidden state, shape (1, N, node_hidden_dim). It must persist
        # across steps rather than be replaced by the conv output.
        h = out.unsqueeze(0)
        for _ in range(self.num_step_message_passing):
            m = F.relu(self.conv(g, out, g.edata['e_feat']))
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        out = self.set2set(g, out)
        out = F.relu(self.lin1(out))
        out = self.lin2(out)
        return out

I suspect the problem is related to the DGL graph size during the forward pass, but I currently have no idea where exactly. Thanks in advance.


I'm willing to help you debug. Could you please highlight the lines you modified, the place where the error occurs, and the error messages?

Also, it seems that there is an extra `h.unsqueeze(0)` in `out, h = self.gru(m.unsqueeze(0), h.unsqueeze(0))` compared with the original implementation.

Thanks, Mufei.

Yes, I changed the shape of h because the error message said the shapes did not match. Sorry, I am currently traveling and cannot access the error message on my university computer. I will send you the error message once I can access it.

Thank you very much.

No worries. Enjoy :grinning:

Thanks, Mufei.
I have followed your suggestion and checked the code again, and now it works. The following script implements the Neural Message Passing with Edge Updates (NMPEU) model proposed in the paper mentioned above. It is based on the SchNet model in the model zoo and differs from it only slightly (the edges are also updated).
Please feel free to add it to the model zoo if it is suitable for that. Thanks.

# -*- coding:utf-8 -*-

import dgl
import torch as th
import torch.nn as nn
import numpy as np
import dgl.function as fn
from torch.nn import Softplus
from layers import AtomEmbedding,  ShiftSoftplus, RBFLayer

class CFConv(nn.Module):
    """Continuous-filter convolution layer with edge updates (NMPEU).

    Each call first updates every edge from the features of its two
    endpoints and its current ``'rbf'`` feature. The updated edge is then
    turned into a filter ``'h'`` that weights neighbour features during
    message passing.
    """

    def __init__(self, dim=64, act="sp"):
        """
        Args:
            dim: dimension of the node features, the edge features and the
                linear layers.
            act: ``"sp"`` selects ``nn.Softplus(beta=0.5, threshold=14)``
                (a plain, unshifted softplus). Any other value is used
                directly as the activation callable.
        """
        super().__init__()
        self._dim = dim
        # cat(src node, dst node, edge) -> 2*dim -> dim: the updated edge.
        self.linear_layer1 = nn.Linear(int(self._dim * 3), int(self._dim * 2))
        self.linear_layer2 = nn.Linear(int(self._dim * 2), self._dim)
        # Updated edge -> filter used to weight neighbour features.
        self.linear_layer3 = nn.Linear(self._dim, self._dim)
        self.linear_layer4 = nn.Linear(self._dim, self._dim)

        if act == "sp":
            self.activation = nn.Softplus(beta=0.5, threshold=14)
        else:
            self.activation = act

    def update_edge(self, edges):
        """Return the new edge feature ``'rbf'`` and the edge filter ``'h'``."""
        rbf = th.cat(
            (edges.src['node'], edges.dst['node'], edges.data['rbf']), -1)
        rbf = self.linear_layer1(rbf)
        rbf = self.activation(rbf)
        rbf = self.linear_layer2(rbf)
        h = self.linear_layer3(rbf)
        h = self.activation(h)
        h = self.linear_layer4(h)
        h = self.activation(h)
        return {"h": h, "rbf": rbf}

    def forward(self, g):
        """Update the edges, then set ``new_node_i = sum_j new_node_j * h_ji``.

        Reads ``g.ndata['node']``, ``g.ndata['new_node']`` and
        ``g.edata['rbf']``. Overwrites ``g.edata['rbf']``, ``g.edata['h']``
        and ``g.ndata['new_node']``.

        Returns:
            A tuple (aggregated node features, updated edge features).
        """
        # Must run first: it creates the edge filter 'h' read below.
        g.apply_edges(self.update_edge)
        g.update_all(message_func=fn.u_mul_e('new_node', 'h', 'neighbor_info'),
                     reduce_func=fn.sum('neighbor_info', 'new_node'))
        return g.ndata["new_node"], g.edata["rbf"]

class NMPEUInteraction(nn.Module):
    """Interaction layer of the Neural Message Passing with Edge Updates model.

    Applies a residual node update ``node += MLP(CFConv(W1 @ node))``. The
    edge features returned by the CFConv are stored back in
    ``g.edata['rbf']``.
    """

    def __init__(self, rbf_dim, dim):
        """
        Args:
            rbf_dim: size of the RBF expansion. It is kept for interface
                compatibility; the layers built here only use ``dim``.
            dim: dimension of the node and edge features.
        """
        super().__init__()
        self._node_dim = dim
        self.activation = nn.Softplus(beta=0.5, threshold=14)
        self.node_layer1 = nn.Linear(dim, dim, bias=False)
        self.cfconv = CFConv(dim, act=self.activation)
        self.node_layer2 = nn.Linear(dim, dim)
        self.node_layer3 = nn.Linear(dim, dim)

    def forward(self, g):
        """Update ``g.ndata['node']`` and ``g.edata['rbf']`` in place and return ``g``."""
        g.ndata["new_node"] = self.node_layer1(g.ndata["node"])
        cf_node, cf_edge = self.cfconv(g)
        cf_node = self.activation(self.node_layer2(cf_node))
        new_node = self.node_layer3(cf_node)
        # Residual connection on the node features.
        g.ndata["node"] = g.ndata["node"] + new_node
        # Keep the CFConv-updated edge features for the next interaction layer.
        g.edata["rbf"] = cf_edge
        return g

class NMPEUModel(nn.Module):
    """NMPEU model: a SchNet-style network whose interaction layers also update edges.

    Pipeline: atom embedding -> RBF expansion -> linear edge projection ->
    ``n_conv`` ``NMPEUInteraction`` layers -> per-atom MLP -> sum over atoms.
    Reference: "Neural Message Passing with Edge Updates for Predicting
    Properties of Molecules and Materials".
    """

    def __init__(self,
                 dim=64,
                 cutoff=5.0,
                 output_dim=1,
                 width=1,
                 n_conv=3,
                 norm=False,
                 atom_ref=None,
                 pre_train=None):
        """
        Args:
            dim: dimension of the features.
            cutoff: radius cutoff passed to the RBF layer.
            output_dim: dimension of the prediction.
            width: width in the RBF function.
            n_conv: number of interaction layers.
            norm: if True, de-normalise the per-atom outputs with the
                mean/std given to ``set_mean_std``, which must be called
                before the first forward pass.
            atom_ref: passed as ``pre_train`` to a 1-d ``AtomEmbedding``
                ``e0`` whose value is added to each atom's output. None
                disables this term.
            pre_train: passed to ``AtomEmbedding`` as pre-trained
                embeddings. None builds a ``dim``-sized embedding instead.
        """
        super().__init__()
        self.name = "NMPEU"
        self._dim = dim
        self.cutoff = cutoff
        self.width = width
        self.n_conv = n_conv
        self.atom_ref = atom_ref
        self.norm = norm
        self.activation = ShiftSoftplus()

        if atom_ref is not None:
            self.e0 = AtomEmbedding(1, pre_train=atom_ref)
        if pre_train is None:
            self.embedding_layer = AtomEmbedding(dim)
        else:
            self.embedding_layer = AtomEmbedding(pre_train=pre_train)
        self.rbf_layer = RBFLayer(0, cutoff, width)
        self.edge_linear1 = nn.Linear(self.rbf_layer._fan_out, dim)
        self.conv_layers = nn.ModuleList(
            [NMPEUInteraction(self.rbf_layer._fan_out, dim) for _ in range(n_conv)])

        self.atom_dense_layer1 = nn.Linear(dim, 64)
        self.atom_dense_layer2 = nn.Linear(64, output_dim)

    def set_mean_std(self, mean, std, device="cpu"):
        """Store the per-atom mean/std used to de-normalise outputs when ``norm`` is set."""
        self.mean_per_atom = th.tensor(mean, device=device)
        self.std_per_atom = th.tensor(std, device=device)

    def update_edge(self, edges):
        """Project the RBF edge features to ``dim`` and apply the activation."""
        rbf = edges.data["rbf"]
        rbf = self.edge_linear1(rbf)
        rbf = self.activation(rbf)
        return {"rbf": rbf}

    def forward(self, g):
        """Return the per-graph prediction for the DGL graph ``g``.

        Intermediate node features (``"node"``, ``"res"``, ...) and edge
        features (``"rbf"``, ...) are written into ``g`` as a side effect.
        """
        # NOTE(review): presumably writes g.ndata["node"] from
        # g.ndata["node_type"] -- confirm against layers.AtomEmbedding.
        self.embedding_layer(g)
        if self.atom_ref is not None:
            self.e0(g, "e0")
        # NOTE(review): presumably writes g.edata["rbf"] from
        # g.edata["distance"] -- confirm against layers.RBFLayer.
        self.rbf_layer(g)
        g.apply_edges(self.update_edge)  # project the RBF features to `dim`
        for conv in self.conv_layers:
            conv(g)

        atom = self.atom_dense_layer1(g.ndata["node"])
        atom = self.activation(atom)
        res = self.atom_dense_layer2(atom)
        g.ndata["res"] = res

        if self.atom_ref is not None:
            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]

        if self.norm:
            g.ndata["res"] = (g.ndata["res"] * self.std_per_atom
                              + self.mean_per_atom)
        res = dgl.sum_nodes(g, "res")
        return res

if __name__ == "__main__":
    # Smoke test: two atoms, edges in both directions plus self-loops.
    g = dgl.DGLGraph()
    # Create the nodes first; some DGL versions reject edges to missing nodes.
    g.add_nodes(2)
    g.add_edges([0, 0, 1, 1], [1, 0, 1, 0])
    g.edata["distance"] = th.tensor([1.0, 3.0, 2.0, 4.0]).reshape(-1, 1)
    g.ndata["node_type"] = th.LongTensor([1, 2])
    model = NMPEUModel(dim=2)
    res = model(g)
    print(res)
1 Like


We are glad to add the model to our model zoo. Could you open a PR for this so that the credit is preserved for you? You can add a new file at

Excellent work!!:smiley: