Precompute the normalized adjacency matrix for Graph Convolution

Hi,

I noticed in the source code of GraphConv that, when the ‘norm’ option enables normalization, the adjacency matrix is normalized again every time the forward() method is called. I am wondering whether this wastes time :thinking:.

If so, is there a way to precompute the normalized adjacency matrix once and skip the corresponding computation in the forward() method?

Thank you very much!

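Yes, it’s possible: compute the normalized edge weights once, store them on the graph, and pass them to a GraphConv variant that takes edge weights as an input instead of normalizing on every forward() call. Below is an example.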

import dgl.function as fn
import torch as th
from torch import nn
from torch.nn import init

from dgl.base import DGLError
from dgl.utils import expand_as_pair

class GraphConv(nn.Module):
    r"""Graph convolution layer that consumes precomputed edge weights.

    This layer performs no degree normalization of its own. The caller passes
    one weight per edge (for example, normalization coefficients computed once
    before training), and the layer computes

    .. math::
        h_v = \sigma\Big(\sum_{(u, v) \in E} w_{uv}\, h_u W + b\Big)

    where :math:`\sigma` is the optional activation and :math:`b` the optional
    bias.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    out_feats : int
        Output feature size.
    bias : bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation: callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Attributes
    ----------
    weight : torch.Tensor
        The learnable weight tensor.
    bias : torch.Tensor
        The learnable bias tensor.
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 bias=True,
                 activation=None):
        super(GraphConv, self).__init__()

        self._in_feats = in_feats
        self._out_feats = out_feats
        # Allocated uninitialized here; filled in by reset_parameters() below.
        self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))

        if bias:
            self.bias = nn.Parameter(th.Tensor(out_feats))
        else:
            # Registering None keeps ``self.bias`` a valid attribute for forward().
            self.register_parameter('bias', None)

        self.reset_parameters()

        self._activation = activation

    def reset_parameters(self):
        """Reinitialize learnable parameters (Xavier-uniform weight, zero bias)."""
        if self.weight is not None:
            init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(self, graph, feat, eweight):
        r"""Compute graph convolution with the given edge weights.

        Notes
        -----
        * Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional
          dimensions, :math:`N` is the number of nodes.
        * Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
          the same shape as the input.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
            Note that in the special case of graph convolutional networks, if a pair of
            tensors is given, the latter element will not participate in computation.
        eweight : torch.Tensor of shape (E, 1)
            Values associated with the edges in the adjacency matrix, e.g. the
            precomputed normalization coefficients. Used as-is; no further
            normalization is applied.

        Returns
        -------
        torch.Tensor
            The output feature
        """
        with graph.local_scope():
            # Destination features are not used by this convolution.
            feat_src, _ = expand_as_pair(feat, graph)

            # Apply W before aggregation when it shrinks the feature size, so
            # that message passing moves less data; otherwise apply it after.
            mult_first = self._in_feats > self._out_feats
            if mult_first:
                feat_src = th.matmul(feat_src, self.weight)

            # Weighted sum of neighbour features: h_v = sum_u w_uv * h_u.
            graph.srcdata['h'] = feat_src
            graph.edata['w'] = eweight
            graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
            rst = graph.dstdata['h']

            if not mult_first:
                rst = th.matmul(rst, self.weight)

            if self.bias is not None:
                rst = rst + self.bias

            if self._activation is not None:
                rst = self._activation(rst)

            return rst

    def extra_repr(self):
        """Set the extra representation of the module,
        which will come into effect when printing the model.
        """
        # Normalization happens outside this layer (via ``eweight``), so there
        # is no ``_norm`` attribute; formatting '{_norm}' raised KeyError.
        summary = 'in={_in_feats}, out={_out_feats}'
        # A Module-typed activation is stored in _modules rather than
        # __dict__, and nn.Module prints it as a child module anyway.
        if '_activation' in self.__dict__:
            summary += ', activation={_activation}'
        return summary.format(**self.__dict__)

# The symmetric normalization can be computed once, up front, as follows.
# NOTE(review): assumes `g` is a homogeneous DGLGraph defined earlier.
# Edge u -> v gets weight out_deg(u)^-1/2 * in_deg(v)^-1/2: the source side
# uses its out-degree and the destination side its in-degree.
# Degrees are clamped to >= 1 so zero-degree nodes give a finite weight
# instead of 0 ** -0.5 = inf.
in_degs = g.in_degrees().float().clamp(min=1)
in_norm = th.pow(in_degs, -0.5).unsqueeze(-1)
out_degs = g.out_degrees().float().clamp(min=1)
out_norm = th.pow(out_degs, -0.5).unsqueeze(-1)
g.ndata['in_norm'] = in_norm
g.ndata['out_norm'] = out_norm
# In u_mul_v(lhs, rhs, out), lhs is read from the source node and rhs from the
# destination node, so the source's out-degree norm must come first.
# The result has shape (E, 1), as GraphConv.forward expects for `eweight`.
g.apply_edges(fn.u_mul_v('out_norm', 'in_norm', 'eweight'))

Hi Mufei,

Thank you very much for your help! :clap: