Yes, it’s possible. Below is an example.
import dgl.function as fn
import torch as th
from torch import nn
from torch.nn import init
from dgl.base import DGLError
from dgl.utils import expand_as_pair
class GraphConv(nn.Module):
    r"""Graph convolution layer whose messages are scaled by caller-supplied edge weights.

    For every edge, the source node feature is multiplied by the matching row
    of ``eweight`` and the products are summed at the destination node. The
    aggregated features are projected with ``weight``, then the optional bias
    and activation are applied.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    out_feats : int
        Output feature size.
    bias : bool, optional
        If True, adds a learnable bias to the output. Default: ``True``.
    activation: callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Attributes
    ----------
    weight : torch.Tensor
        The learnable weight tensor.
    bias : torch.Tensor
        The learnable bias tensor.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 bias=True,
                 activation=None):
        super().__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        # Allocated uninitialized; reset_parameters() below fills in the values.
        self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
        if bias:
            self.bias = nn.Parameter(th.Tensor(out_feats))
        else:
            # Register the name so that ``self.bias`` is always a valid attribute.
            self.register_parameter('bias', None)
        self.reset_parameters()
        self._activation = activation

    def reset_parameters(self):
        """Reinitialize learnable parameters.

        The weight is drawn with Xavier-uniform initialization and the bias
        (when present) is set to zero.
        """
        if self.weight is not None:
            init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(self, graph, feat, eweight):
        r"""Compute graph convolution.

        Notes
        -----
        * Input shape: :math:`(N, *, \text{in_feats})` where * means any number of additional
          dimensions, :math:`N` is the number of nodes.
        * Output shape: :math:`(N, *, \text{out_feats})` where all but the last dimension are
          the same shape as the input.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
            Note that in the special case of graph convolutional networks, if a pair of
            tensors is given, the latter element will not participate in computation.
        eweight : torch.Tensor of shape (E, 1)
            Values associated with the edges in the adjacency matrix.

        Returns
        -------
        torch.Tensor
            The output feature
        """
        # local_scope keeps the temporary 'h' / 'w' fields out of the caller's graph.
        with graph.local_scope():
            feat_src, _ = expand_as_pair(feat, graph)

            # Apply W on the cheaper side of the aggregation: before it when W
            # shrinks the feature size, after it otherwise.
            project_first = self._in_feats > self._out_feats
            if project_first:
                feat_src = th.matmul(feat_src, self.weight)

            graph.srcdata['h'] = feat_src
            graph.edata['w'] = eweight
            # message = source feature * edge weight; reduce = sum over incoming edges.
            graph.update_all(fn.u_mul_e('h', 'w', 'm'),
                             fn.sum('m', 'h'))
            rst = graph.dstdata['h']

            if not project_first:
                rst = th.matmul(rst, self.weight)

            if self.bias is not None:
                rst = rst + self.bias

            if self._activation is not None:
                rst = self._activation(rst)

            return rst

    def extra_repr(self):
        """Set the extra representation of the module,
        which will come into effect when printing the model.
        """
        # This layer defines no ``_norm`` attribute, so no normalization field
        # is reported; formatting one would raise KeyError when printing.
        summary = 'in={_in_feats}, out={_out_feats}'
        # Only an activation stored as a plain attribute lives in __dict__;
        # an nn.Module activation is registered as a submodule and is shown
        # by the default module repr instead.
        if '_activation' in self.__dict__:
            summary += ', activation={_activation}'
        return summary.format(**self.__dict__)
# The norm can be computed as follows (one weight per edge u -> v, written
# to g.edata['eweight']).
# Degrees are clamped to at least 1 so that a node with zero degree gets a
# finite factor instead of the inf produced by 0 ** -0.5.
in_degs = g.in_degrees().float().clamp(min=1)
in_norm = th.pow(in_degs, -0.5).unsqueeze(-1)
out_degs = g.out_degrees().float().clamp(min=1)
out_norm = th.pow(out_degs, -0.5).unsqueeze(-1)
g.ndata['in_norm'] = in_norm
g.ndata['out_norm'] = out_norm
# Scale by the source node's out-degree and the destination node's in-degree.
# On a symmetric graph the two degrees coincide; on a directed graph the
# order matters.
g.apply_edges(fn.u_mul_v('out_norm', 'in_norm', 'eweight'))