Invariant Gathering Operation

I'm working on a graph model for materials science where the gathering step concatenates the features of adjacent nodes connected by edges in a geometric order. Unlike the usual gathering operations that use sum or max, here we use concatenation, where the order of concatenation follows a geometric arrangement and is precalculated.
I have figured out how to do this using PyTorch alone, but I was wondering if there is any model in DGL where this kind of invariant gathering is done.

Ideas are taken from the following paper given here (Model details on Page 4)

With concatenation, do you assume that all graphs have the same number of nodes? If not, do you use zero padding? It would also be very helpful if you could provide a code snippet for doing the operation in pure PyTorch.

In this particular model, all nodes have an equal number of adjacent nodes, so padding wasn't necessary. In the code below, the neighbours are already precomputed and passed along with the node features. The features of the neighbours are concatenated and a linear operation is applied to them.

Here is the PyTorch code -

def forward(self, feed_dict):
        """
        X_Sites: 2D tensor --> (Number of sites, size of node feature)
        X_NSs: 3D tensor --> (Number of sites, Number of permutations, neighbor size)
        feed_dict["X_Sites"], feed_dict["X_NSs"]
        """
        updates = []
        for i in range(len(feed_dict["X_NSs"])):

            # Neighbors of the current node i along with all the permutations
            X_NSs_i = feed_dict["X_NSs"][i]

            # Gather neighbour features and flatten each permutation:
            # Number of permutations * (node_feature_size * neighbor_size)
            X_site_i = feed_dict["X_Sites"][X_NSs_i].view(X_NSs_i.shape[0], -1)

            # Number of permutations * depth
            X_1 = self.conv_weights(X_site_i) + self.bias
            if self.UseBN:
                X_1 = self.batch_norm(X_1)

            # Shifted softplus activation function
            X_2 = self.activation(X_1)
            # Dropout randomly zeros out some permutation rows
            X_3 = self.dropout(X_2)
            # Sum over permutations for invariance
            X_4 = X_3.sum(dim=0)
            updates.append(X_4)

        updates = torch.stack(updates, dim=0)
        return updates
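As an aside, the per-node loop above can be collapsed into a single advanced-indexing gather. This is only a minimal sketch under the shape assumptions stated in the docstring; the sizes are made up for the demo and an `nn.Linear` stands in for `self.conv_weights`:

```python
import torch
import torch.nn as nn

# Assumed shapes, taken from the docstring above:
#   X_Sites: (num_sites, feat)            node features
#   X_NSs:   (num_sites, num_perm, nbr)   precomputed neighbour indices
num_sites, feat, num_perm, nbr, depth = 5, 4, 2, 3, 8
X_Sites = torch.randn(num_sites, feat)
X_NSs = torch.randint(0, num_sites, (num_sites, num_perm, nbr))

# Hypothetical stand-in for self.conv_weights
linear = nn.Linear(nbr * feat, depth)

# Gather all neighbour features at once via advanced indexing:
# (num_sites, num_perm, nbr, feat) -> flatten the last two dims
gathered = X_Sites[X_NSs].reshape(num_sites, num_perm, nbr * feat)

# Apply the linear layer to every permutation, then sum over permutations
updates = linear(gathered).sum(dim=1)   # (num_sites, depth)
print(updates.shape)  # torch.Size([5, 8])
```

Since the gather is one indexing operation, this avoids the Python-level loop entirely (activation, batch norm, and dropout would slot in between the linear layer and the sum exactly as in the loop version).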

In this particular model, all nodes have an equal number of adjacent nodes -

Ok, so we assume a setting of regular graphs.

How many graphs do you have? The code snippet seems to handle only one graph at a time. How should I interpret “permutation” here?

I think it will be better if I represent the graph operation with a small example. (Sorry, I'm unable to explain it well in words.)

Let’s assume a given graph -


       A1         A2
          \     /
A6  __       A     __  A3
          /     \
       A5         A4

Let’s assume we are performing a graph operation on node A and its neighbours {A1, A2, A3, A4, A5, A6}.

In the gathering operation, invariance is maintained by selecting a geometrical order of permutation in which the edges are concatenated.

Permutation 1 - {A1, A2, A3, A4, A5, A6}
Permutation 2 - {A3, A4, A5, A6, A1, A2}

So the gathering operation involves concatenation of different permutations of the neighbours.

This is a toy example; in the actual model each node has exactly 19 neighbours and 6 permutations. (This is because of lattice periodicity.)

Let’s say it can handle only one graph at a time, and each graph has a variable number of nodes (15 on average), but each node has a fixed 19 neighbours. I'm sorry if it sounds weird, but these neighbours are calculated using isomorphism algorithms on the crystal lattice.
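The two permutations in the toy example above can be encoded as rows of an index tensor, and the gather implemented as one indexing-plus-reshape. A minimal sketch, with the neighbours A1..A6 indexed 0..5 and a made-up feature size of 2:

```python
import torch

# Toy illustration: the 6 neighbours A1..A6 of node A, indexed 0..5,
# each with a 2-d feature (values chosen so rows are easy to identify).
feats = torch.arange(12.0).reshape(6, 2)

# The two geometric permutations from the example above
perms = torch.tensor([
    [0, 1, 2, 3, 4, 5],   # Permutation 1: {A1, A2, A3, A4, A5, A6}
    [2, 3, 4, 5, 0, 1],   # Permutation 2: {A3, A4, A5, A6, A1, A2}
])

# Gathering = concatenating neighbour features in each permutation's order
gathered = feats[perms].reshape(len(perms), -1)   # (2 permutations, 6 * 2)
print(gathered.shape)  # torch.Size([2, 12])
```

Row 0 of `gathered` is the features of A1..A6 concatenated in order; row 1 is the same features starting from A3, matching the cyclic shift in the example.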

I see your point. You may implement a model as follows:

import dgl
import dgl.function as fn
import torch

# An example graph. 
# Note that the edges encode two permutations 
# of incoming edges for each node. E.g. {(1, 0), (2, 0)} and {(2, 0), (1, 0)}
src = torch.tensor([1, 2, 2, 1, 2, 0, 0, 2, 1, 0, 0, 1])
dst = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
g = dgl.graph((src, dst))
# Initialize random node features for demo
g.ndata['h'] = torch.randn(g.num_nodes(), 2)

def reduce_func(nodes):
    print(nodes.mailbox['m'])
    return {'X_site': nodes.mailbox['m']}
g.update_all(fn.copy_u('h', 'm'), reduce_func)
num_permutations = 2

g.ndata['X_site'] corresponds to X_site_i for all nodes as in your example.
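To mirror the rest of the original forward(), the gathered tensor can then be split back into permutations and passed through the linear layer. A minimal torch-only sketch, where a random tensor stands in for g.ndata['X_site'] (assumed shape: nodes x incoming edges x features, as produced by the snippet above) and a hypothetical nn.Linear stands in for self.conv_weights:

```python
import torch
import torch.nn as nn

# Stand-in for g.ndata['X_site'] from the DGL snippet above:
# 3 nodes, 2 permutations x 2 neighbours = 4 incoming edges, 2-d features.
num_nodes, num_perm, nbr_size, feat, depth = 3, 2, 2, 2, 8
X_site = torch.randn(num_nodes, num_perm * nbr_size, feat)

# Group the incoming messages by permutation and flatten neighbour features
X_site = X_site.view(num_nodes, num_perm, nbr_size * feat)

# Hypothetical stand-in for self.conv_weights in the original forward()
conv_weights = nn.Linear(nbr_size * feat, depth)
updates = conv_weights(X_site).sum(dim=1)   # sum over permutations
print(updates.shape)  # torch.Size([3, 8])
```

This reproduces the per-node X_site_i -> linear -> sum-over-permutations pipeline of the original loop, but for all nodes at once.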

So the incoming edges in mailbox['m'] are in the same order in which the graph edges are defined. This helped me a lot, thanks!
