Set2Set model in DGL

Set2Set is a powerful global pooling operator based on iterative content-based attention. It can be used in DGL as shown below.

Reference
Paper | Code
PyG implementation

import torch
import dgl
from torch import tensor
from torch.nn.functional import softmax
class Set2Set(torch.nn.Module):
    def __init__(self, 
                 in_channels,
                 processing_steps, 
                 num_layers=1):
        super(Set2Set, self).__init__()

        self.in_channels = in_channels
        self.out_channels = 2 * in_channels
        self.processing_steps = processing_steps
        self.num_layers = num_layers

        self.lstm = torch.nn.LSTM(self.out_channels, self.in_channels,
                                  num_layers)

        self.reset_parameters()

    def reset_parameters(self):
        self.lstm.reset_parameters()

    def forward(self, bg, feat):
        batch_size = bg.batch_size
        x = bg.ndata[feat]
        # Map every node to the index of the graph it belongs to,
        # built on the same device as the node features.
        batch_num_nodes = bg.batch_num_nodes
        batch = torch.repeat_interleave(
            torch.arange(batch_size, device=x.device),
            torch.as_tensor(batch_num_nodes, device=x.device))
        
        h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
             x.new_zeros((self.num_layers, batch_size, self.in_channels)))
        q_star = x.new_zeros(batch_size, self.out_channels)

        for i in range(self.processing_steps):
            q, h = self.lstm(q_star.unsqueeze(0), h)
            q = q.view(batch_size, self.in_channels)
            e = (x * q[batch]).sum(dim=-1, keepdim=True)
            # Softmax over the nodes of each graph separately.
            a = torch.cat([softmax(chunk, dim=0)
                           for chunk in torch.split(e, batch_num_nodes)],
                          dim=0)
            bg.ndata['w'] = a
            r = dgl.sum_nodes(bg, feat, 'w')
            q_star = torch.cat([q, r], dim=-1)

        return q_star


    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

Hi, thanks for your contribution. We have incorporated Set2Set in our Graph Pooling PR: https://github.com/dmlc/dgl/pull/669 and it will be merged soon.

Some functions needed to make softmax/max efficient on a batched graph are still missing. Our PR will add these APIs, so you will no longer need to run softmax on each individual graph in the BatchedDGLGraph, which is inefficient.
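For anyone curious what a softmax over a whole batch might look like without the per-graph loop, here is a minimal sketch in plain PyTorch. It is not the API from the PR: `segment_softmax` and its arguments are hypothetical names, scores are assumed to be 1-D, and `scatter_reduce` is assumed to be available (PyTorch 1.12 or newer).

```python
import torch


def segment_softmax(scores, segment_ids, num_segments):
    """Softmax over `scores`, normalised separately within each segment.

    scores:       1-D float tensor, one score per node
    segment_ids:  1-D int64 tensor, graph index of each node
    num_segments: number of graphs in the batch
    (hypothetical helper, not part of DGL)
    """
    # Per-segment maximum, subtracted for numerical stability.
    seg_max = scores.new_full((num_segments,), float("-inf"))
    seg_max = seg_max.scatter_reduce(0, segment_ids, scores, reduce="amax")
    exp = torch.exp(scores - seg_max[segment_ids])

    # Per-segment sum of exponentials, broadcast back to the nodes.
    seg_sum = scores.new_zeros(num_segments).index_add_(0, segment_ids, exp)
    return exp / seg_sum[segment_ids]


# Two graphs with 3 and 2 nodes respectively.
scores = torch.tensor([1.0, 2.0, 3.0, 0.5, 0.5])
segment_ids = torch.tensor([0, 0, 0, 1, 1])
weights = segment_softmax(scores, segment_ids, num_segments=2)
print(weights)
```

The weights of each graph sum to 1, and the result should match running `softmax` on every chunk from `torch.split` as in the first post.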

Glad to hear that. I know this isn’t efficient (I’m not a professional programmer), but I’m using it as a temporary solution and wasn’t planning to open a pull request. By the way, in which folder of your repository can I find your implementation? :grinning:

You can see the changes in the “Files changed” section of that PR.

Thanks a lot! :heart:

Why does Set2Set return an output of size 2 * in_channels? According to the docs (https://docs.dgl.ai/en/latest/api/python/nn.pytorch.html#set2set) and the comments in the implementation, the input dim and output dim should be equal.

Btw, when is the DiffPool documentation going to be available? Thanks!

Hi @Slowika, sorry for the late reply.
Let’s look at the equations of the Set2Set model:
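Written out to match the forward pass in the first post, the update at step t is roughly the following. This is a reconstruction from that code rather than a quote from the paper: the symbols x_i, e_{i,t}, a_{i,t} and r_t are chosen here to mirror the variables `x`, `e`, `a` and `r`, and the sums run over the nodes of a single graph.

```latex
\begin{aligned}
q_t     &= \mathrm{LSTM}(q^*_{t-1}) \\
e_{i,t} &= x_i \cdot q_t \\
a_{i,t} &= \frac{\exp(e_{i,t})}{\sum_{j} \exp(e_{j,t})} \\
r_t     &= \sum_{i} a_{i,t}\, x_i \\
q^*_t   &= q_t \,\Vert\, r_t
\end{aligned}
```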


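Notice that q^*_t is twice the size of q_t, because it is the concatenation of q_t and the readout. If we want to use the same LSTM for all iterations, the size of q^*_t must equal that of q^*_{t-1}. So the LSTM takes an input of size 2 * in_channels and produces q_t of size in_channels, and the q^*_t returned by Set2Set has size 2 * in_channels.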
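The sizes can be checked with a small stand-alone sketch of one iteration in plain PyTorch. The values of `in_channels` and `batch_size` are arbitrary, and the readout `r` is a zero placeholder rather than a real attention-weighted sum.

```python
import torch

in_channels, batch_size = 4, 3

# Same LSTM shape as in the first post: input 2 * in_channels, hidden in_channels.
lstm = torch.nn.LSTM(2 * in_channels, in_channels)

q_star = torch.zeros(batch_size, 2 * in_channels)
h = (torch.zeros(1, batch_size, in_channels),
     torch.zeros(1, batch_size, in_channels))

# One step: the LSTM consumes q*_{t-1} and emits q_t.
q, h = lstm(q_star.unsqueeze(0), h)
q = q.view(batch_size, in_channels)

# Placeholder for the readout r_t (same size as the node features).
r = torch.zeros(batch_size, in_channels)

# q*_t is the concatenation, so it can be fed back into the same LSTM.
q_star = torch.cat([q, r], dim=-1)
print(q.shape, q_star.shape)
```

Because `q_star` has the same size before and after the step, the loop can run for any number of processing steps with a single LSTM.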

As for the DiffPool documentation, I’ve added it to our roadmap. Thanks for bringing it up.