Question about using a GRU on batched graph node features

Hi,

I have a question about the GRU used in /chem/mpnn.py. Suppose the batched graph contains two DGLGraphs: the first graph g1 has 4 nodes and the second one g2 has 5 nodes. I would expect the node features to be fed into the GRU independently, i.e. g1_new = self.gru(g1, g1_hidden) and g2_new = self.gru(g2, g2_hidden). However, in the current code it is effectively [g1, g2]_new = self.gru([g1, g2], [g1, g2]_hidden) because of the batched graph. This seems strange because a GRU is sequence-related, yet there is no relationship between the last node of g1 and the first node of g2. Is my understanding correct? If so, could you please give me some suggestions on how to solve this? Thanks.

Sorry, I think I understand it now: the GRU here unrolls over the message-passing layers (one step per layer), not over the nodes from first to last. Thanks.
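If my understanding is right, the nodes only play the role of the batch dimension, roughly like this (a sketch with made-up sizes, not the actual code in mpnn.py):

import torch as th
import torch.nn as nn

# Sketch of my understanding: every node of the batched graph is an independent
# element of the GRU's batch dimension, and the message-passing layers are the
# sequence dimension, so batching g1 and g2 together does not mix them.
gru = nn.GRU(16, 16)
num_nodes = 4 + 5                          # nodes of g1 and g2 stacked together
node_feats = th.randn(1, num_nodes, 16)    # seq_len = 1: one GRU step per layer
hidden = th.zeros(1, num_nodes, 16)
for _ in range(3):                         # one step per message-passing layer
    node_feats, hidden = gru(node_feats, hidden)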


Sorry to bother you again. I realised that I still have questions. Is there a built-in function to run a sequence-based layer on each subgraph of a batched graph? Suppose g is a batched graph of [g1, g2], where g1 has nodes [n1, n2, n3] and g2 has nodes [n4, n5, n6, n7]. I want to apply a Transformer layer to the node sequence of each subgraph (e.g. h1, h2, h3 = Transformer([n1, n2, n3])). The Transformer example at https://docs.dgl.ai/tutorials/models/4_old_wines/7_transformer.html#sphx-glr-tutorials-models-4-old-wines-7-transformer-py is based on a complete graph (every node has an edge to every other node in that graph). Is there any method for non-complete graphs? Thanks a lot.
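The closest workaround I can think of is to unbatch and pad the node features myself before a standard sequence layer, roughly like this (just my own sketch, I am not sure it is the intended way):

import dgl
import torch as th

# Sketch of the workaround: unbatch, collect per-graph node features, and pad
# them into one tensor that a standard sequence layer could consume.
g1 = dgl.DGLGraph(); g1.add_nodes(3); g1.ndata['x'] = th.randn(3, 10)
g2 = dgl.DGLGraph(); g2.add_nodes(4); g2.ndata['x'] = th.randn(4, 10)
g = dgl.batch([g1, g2])

feats = [sg.ndata['x'] for sg in dgl.unbatch(g)]                 # [(3, 10), (4, 10)]
padded = th.nn.utils.rnn.pad_sequence(feats, batch_first=True)   # (2, 4, 10)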

Hi,

Could you elaborate a bit more? Why does it make a difference whether the graphs are complete?

Hi Mufei,

I will try to explain what I want to do. The code below may not run as-is; it is just for explanation.
The following is what I want to do with a single graph.

import math

import dgl
import torch as th
import torch.nn as nn
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = th.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return th.matmul(p_attn, value), p_attn

g = dgl.DGLGraph()
g.add_nodes(3)
x = th.randn(3, 10)
g.ndata['x'] = x
q = nn.Linear(10, 20)(g.ndata['x'])
k = nn.Linear(10, 20)(g.ndata['x'])
v = nn.Linear(10, 20)(g.ndata['x'])
output, p_attn = attention(q, k, v)

However, this will not work with a DGL batched graph.

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g2 = dgl.DGLGraph()
g2.add_nodes(2)
x1 = th.randn(3, 10)
x2 = th.randn(2, 10)
g1.ndata['x'] = x1
g2.ndata['x'] = x2
g = dgl.batch([g1, g2])

If we do something similar to the above:

q = nn.Linear(10, 20)(g.ndata['x'])
k = nn.Linear(10, 20)(g.ndata['x'])
v = nn.Linear(10, 20)(g.ndata['x'])
output, p_attn = attention(q, k, v)

In the attention layer, q[0] will be matched with k[3] (a query from g1 attends to a key from g2), which is exactly what needs to be avoided.
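One workaround I can imagine is masking the attention scores by graph membership, roughly like this (my own sketch, reusing the attention() defined above; probably not idiomatic DGL):

# Build a block-diagonal mask from the batch structure so a node only attends
# to the nodes of its own subgraph.
num_nodes = [3, 2]                                       # nodes in g1 and g2
graph_id = th.repeat_interleave(th.arange(len(num_nodes)),
                                th.tensor(num_nodes))    # [0, 0, 0, 1, 1]
mask = graph_id.unsqueeze(0) == graph_id.unsqueeze(1)    # (5, 5), block-diagonal

output, p_attn = attention(q, k, v, mask=mask)           # q[0] no longer sees k[3]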

The current solution in the DGL transformer tutorial:

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g2 = dgl.DGLGraph()
g2.add_nodes(2)
x1 = th.randn(3, 10)
x2 = th.randn(2, 10)
g1.ndata['x'] = x1
g1.add_edges([0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2])  # complete graph on g1
g2.ndata['x'] = x2
g2.add_edges([0, 0, 1, 1], [0, 1, 0, 1])  # complete graph on g2
g = dgl.batch([g1, g2])
g.ndata['q'] = nn.Linear(10, 20)(g.ndata['x'])
g.ndata['k'] = nn.Linear(10, 20)(g.ndata['x'])
g.ndata['v'] = nn.Linear(10, 20)(g.ndata['x'])
import numpy as np
import dgl.function as fn

def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: th.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func

def propagate_attention(g, eids, d_k):
    # Compute attention score
    g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
    g.apply_edges(scaled_exp('score', np.sqrt(d_k)))
    # Update node state
    g.send_and_recv(eids,
                    [fn.src_mul_edge('v', 'score', 'v'), fn.copy_edge('score', 'score')],
                    [fn.sum('v', 'wv'), fn.sum('score', 'z')])

What this does is match the source key with the destination query on each edge, and then aggregate the information from all edges. This is a very clever solution. However, it only works with a complete graph: if there is no edge between nodes 0 and 1 in g1, the information of node 1 will never reach node 0.
One solution I am thinking of is to transform the graph into a complete one before the attention layer, and then delete the added edges afterwards (similar to what was done in the generative models tutorial).
Hope this makes it clear what I want to do. :joy: Thanks.

Thank you, I understand better now. The problem is that attention implicitly assumes a dependency between every pair of tokens/nodes; with the scaled dot product attention in transformers, you are implicitly working with complete graphs. I think working with complete graphs is fine, and you only need to construct these complete graphs once during the preprocessing stage.
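For illustration, a minimal sketch of what that preprocessing could look like (a hypothetical helper, not code from the model zoo):

import itertools
import dgl

# Hypothetical helper: build a complete DGLGraph (without self-loops) for a
# molecule with num_nodes atoms, once, during preprocessing.
def complete_graph(num_nodes):
    g = dgl.DGLGraph()
    g.add_nodes(num_nodes)
    src, dst = zip(*[(i, j)
                     for i, j in itertools.product(range(num_nodes), repeat=2)
                     if i != j])
    g.add_edges(list(src), list(dst))
    return g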

May I ask for some description of the task and the dataset you are working on? Is there any particular reason you do not want to use complete graphs?

Thanks, Mufei. I found that SetTransformerEncoder in dgl/python/dgl/nn/pytorch/glob.py is what I want (not exactly, but quite similar). Anyway, this is one solution I can try. Unfortunately, I cannot get it to work with the SchNet in the model zoo. The functions used inside are quite complicated, so I have no idea how to debug this (I think it is related to backpropagation: the first batched graph is fed in and it works -> backpropagation -> the second batch is fed in and it fails. I will keep trying). I hope you can give me some suggestions here. The following is the model I tested. The only difference is that a SetTransformerEncoder layer was added before atom_dense_layer1. This code led to a NaN training loss. :rofl:

# -*- coding:utf-8 -*-

import dgl
import torch as th
import torch.nn as nn
from layers import AtomEmbedding, Interaction, ShiftSoftplus, RBFLayer
import numpy as np
import dgl.nn.pytorch as dgl_nn

class SchNetModel(nn.Module):
    """
    SchNet Model from:
        Schütt, Kristof, et al.
        SchNet: A continuous-filter convolutional neural network
        for modeling quantum interactions. (NIPS'2017)
    """

    def __init__(self,
                 dim=64,
                 cutoff=5.0,
                 output_dim=1,
                 width=1,
                 n_conv=3,
                 norm=False,
                 atom_ref=None,
                 pre_train=None):
        """
        Args:
            dim: dimension of features
            output_dim: dimension of prediction
            cutoff: radius cutoff
            width: width in the RBF function
            n_conv: number of interaction layers
            atom_ref: used as the initial value of atom embeddings,
                      or set to None with random initialization
            norm: normalization
        """
        super().__init__()
        self.name = "SchNet"
        self._dim = dim
        self.cutoff = cutoff
        self.width = width
        self.n_conv = n_conv
        self.atom_ref = atom_ref
        self.norm = norm
        self.activation = ShiftSoftplus()

        if atom_ref is not None:
            self.e0 = AtomEmbedding(1, pre_train=atom_ref)
        if pre_train is None:
            self.embedding_layer = AtomEmbedding(dim)
        else:
            self.embedding_layer = AtomEmbedding(pre_train=pre_train)
        self.rbf_layer = RBFLayer(0, cutoff, width)
        self.conv_layers = nn.ModuleList(
            [Interaction(self.rbf_layer._fan_out, dim) for i in range(n_conv)])
        # Add a SetTransformerEncoder here
        self.attention = dgl_nn.glob.SetTransformerEncoder(dim, 2, dim, dim)
        self.atom_dense_layer1 = nn.Linear(dim, 64)
        self.atom_dense_layer2 = nn.Linear(64, output_dim)

    def set_mean_std(self, mean, std, device="cpu"):
        self.mean_per_atom = th.tensor(mean, device=device)
        self.std_per_atom = th.tensor(std, device=device)

    def forward(self, g):
        """g is the DGL.graph"""

        self.embedding_layer(g)
        if self.atom_ref is not None:
            self.e0(g, "e0")
        self.rbf_layer(g)
        for idx in range(self.n_conv):
            self.conv_layers[idx](g)
        atom = self.attention(g, g.ndata["node"])
        atom = self.atom_dense_layer1(atom)
        atom = self.activation(atom)
        res = self.atom_dense_layer2(atom)
        g.ndata["res"] = res

        if self.atom_ref is not None:
            g.ndata["res"] = g.ndata["res"] + g.ndata["e0"]

        if self.norm:
            g.ndata["res"] = g.ndata[
                "res"] * self.std_per_atom + self.mean_per_atom
        res = dgl.sum_nodes(g, "res")
        return res


if __name__ == "__main__":
    g = dgl.DGLGraph()
    g.add_nodes(2)
    g.add_edges([0, 0, 1, 1], [1, 0, 1, 0])
    g.edata["distance"] = th.tensor([1.0, 3.0, 2.0, 4.0]).reshape(-1, 1)
    g.ndata["node_type"] = th.LongTensor([1, 2])
    model = SchNetModel(dim=2)
    atom = model(g)
    print(atom)

I am still testing with the Tencent Alchemy dataset. In the official implementation by qlab, each graph is complete (https://github.com/tencent-alchemy/Alchemy/blob/47c641a00ccc782aba0d008a4fbd7a137eb484b9/dgl/Alchemy_dataset.py#L178). However, complete graphs conflict with my chemistry intuition: for two atoms that are far apart, their effect on each other should be negligible. Therefore I deleted the edges with long distances, as in the sketch below. What I want to do is quite similar to this paper: https://arxiv.org/pdf/1905.12712.pdf; you could take a look if you are interested.
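A rough sketch of the edge filtering I have in mind (a hypothetical preprocessing step, not the Alchemy loader; coords and cutoff are my own names):

import dgl
import torch as th

# Hypothetical preprocessing: keep only edges whose pairwise distance is within
# a cutoff when building each molecular graph, instead of a complete graph.
def build_graph(coords, cutoff=5.0):
    n = coords.shape[0]
    dist = th.cdist(coords, coords)                       # (n, n) pairwise distances
    keep = (dist <= cutoff) & ~th.eye(n, dtype=th.bool)   # drop self-loops
    idx = keep.nonzero()                                  # (num_edges, 2)
    src, dst = idx[:, 0], idx[:, 1]
    g = dgl.DGLGraph()
    g.add_nodes(n)
    g.add_edges(src, dst)
    g.edata["distance"] = dist[src, dst].reshape(-1, 1)
    return g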

Thanks a lot.

Did you examine the inputs and outputs of self.attention? Any unexpected behavior?
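For example, something along these lines might help locate where the NaNs first appear (a debugging sketch, not part of the model zoo code):

import torch as th

# Debugging sketch: register a forward hook on the attention module and enable
# autograd anomaly detection to find where NaNs first show up.
def nan_check_hook(module, inputs, output):
    for t in inputs:
        if th.is_tensor(t) and th.isnan(t).any():
            print("NaN in an input of", module.__class__.__name__)
    if th.is_tensor(output) and th.isnan(output).any():
        print("NaN in the output of", module.__class__.__name__)

# model.attention.register_forward_hook(nan_check_hook)  # model is your SchNetModel
# th.autograd.set_detect_anomaly(True)                   # reports the op that produced NaN grads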

I’ve read that paper before. Currently we also support undirected graphs and you can try changing this line.