Order in SAGEConv LSTM aggregator

Hi,
I am trying to understand how the LSTM aggregator works in SAGEConv, the GraphSAGE layer in DGL (see the DGL docs).

There, message creation and aggregation look like this:

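graph.srcdata['h'] = feat_src               # source-node features
graph.update_all(fn.copy_src('h', 'm'),     # message: copy each source node's 'h' onto its out-edges as 'm'
                 self._lstm_reducer)        # reduce: run the LSTM over each destination node's mailbox
h_neigh = graph.dstdata['neigh']            # aggregated neighbor representation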

With self._lstm_reducer defined by:

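def _lstm_reducer(self, nodes):
    m = nodes.mailbox['m']  # (B, L, D): B nodes, each with L incoming messages of size D
    batch_size = m.shape[0]
    # Zero initial (hidden, cell) state, each of shape (num_layers=1, B, D)
    h = (m.new_zeros((1, batch_size, self._in_src_feats)),
         m.new_zeros((1, batch_size, self._in_src_feats)))
    # The LSTM reads each node's L messages in order;
    # rst is the final hidden state h_n, of shape (1, B, D)
    _, (rst, _) = self.lstm(m, h)
    return {'neigh': rst.squeeze(0)}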

When I try this on my own heterogeneous graph and inspect m before it goes into the LSTM, m has shape torch.Size([26, 1, 3]). If I understand correctly, this is [batch_size, number of nodes in the sequence, dimensionality of node features]. Is that correct?

In this case, I assume that rst is the hidden state after a ‘sequence’ of 1 input node (which would mean that each node only considers 1 neighbor). Is that correct? Or is it the other way around, i.e. rst is the hidden state after a ‘sequence’ of 26 input nodes, so that 26 neighbors are considered?

If it is a sequence of 26 nodes, how does DGL determine the order of these nodes?

Thanks a lot in advance!

When I try this on my own heterogeneous graph and inspect m before it goes into the LSTM, m has shape torch.Size([26, 1, 3]). If I understand correctly, this is [batch_size, number of nodes in the sequence, dimensionality of node features]. Is that correct?

Strictly speaking, this is [batch_size, number of messages/neighbors, dimensionality of node features]. Every node in the batch has the same number of messages because of the “degree bucketing” mechanism DGL uses to optimize computation.

In this case, I assume that rst is the hidden state after a ‘sequence’ of 1 input node (which would mean that each node only considers 1 neighbor). Is that correct? Or is it the other way around, i.e. rst is the hidden state after a ‘sequence’ of 26 input nodes, so that 26 neighbors are considered?

It updates the representations of a batch of nodes, taking all in-neighbors of each node into account. In this particular batch, each of the 26 nodes has only one in-neighbor.
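To make the shapes concrete, here is a stdlib-only sketch of what a recurrent reducer does with one batch: it walks each node's L messages in order and keeps only the final state, so a (B, L, D) mailbox becomes a (B, D) result. The hypothetical `toy_recurrent_reduce` uses the recurrence `h = 0.5 * h + x` as a stand-in for the LSTM cell; it is not what `nn.LSTM` computes, and only the shapes and the "last state after the whole sequence" behavior carry over.

```python
# Hypothetical stand-in for _lstm_reducer (stdlib only, not DGL or PyTorch code).
# mailbox is a nested list shaped (B, L, D), like nodes.mailbox['m'].


def toy_recurrent_reduce(mailbox, decay=0.5):
    """Return the final 'hidden state' of each node, shaped (B, D).

    Uses h_t = decay * h_{t-1} + x_t per feature instead of an LSTM cell;
    h_0 is all zeros, mirroring the zero initial state in _lstm_reducer.
    """
    result = []
    for messages in mailbox:              # one node of the batch
        h = [0.0] * len(messages[0])      # zero initial state
        for x in messages:                # walk the L messages in order
            h = [decay * hv + xv for hv, xv in zip(h, x)]
        result.append(h)                  # keep only the last state
    return result


# Nodes with in-degree 1 (L = 1, D = 3): a single recurrent step.
batch_deg1 = [[[1.0, 2.0, 3.0]], [[0.0, -1.0, 4.0]]]
print(toy_recurrent_reduce(batch_deg1))   # [[1.0, 2.0, 3.0], [0.0, -1.0, 4.0]]

# A node with in-degree 2: the final state depends on both messages.
batch_deg2 = [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]
print(toy_recurrent_reduce(batch_deg2))   # [[0.5, 1.0, 0.0]]
```

With L = 1, as in the [26, 1, 3] batch, the result is one step over the single neighbor's message; with larger L, the state reflects all L messages.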

Thanks @mufeili for the answer.

I understand that, in this case, the shape of m indicates that each of the 26 nodes has only one in-neighbor. However, when I inspect my graph at this point (before it goes into the LSTM), I see that nodes usually have more than one in-neighbor:

def forward(self, graph, x):
    h_neigh, h_self = x
    print(graph.in_degrees(etype=etype))  # added for inspection; etype comes from code not shown
    if self._aggre_type == 'lstm':
        graph.srcdata['h'] = h_neigh
        graph.update_all(
            fn.copy_src('h', 'm'),
            self._lstm_reducer)
        h_neigh = graph.dstdata['neigh']

Output:
tensor([26, 5, 15, ..., 2, 1, 2])

And logically (based on my data), I would expect the nodes to have multiple in-neighbors. Do you have any intuition about why m only includes nodes with one in-neighbor, even though many nodes in the graph have several?

Thanks again!

As I mentioned above, this is related to a degree-bucketing mechanism for computation optimization. Basically, DGL groups nodes with the same in-degree into one batch, so that their mailboxes can be stacked into a single (B, L, D) tensor and processed in one call of the reduce function. The batch you printed happened to be the one for nodes with in-degree 1. If you print more batches, you will see batches for nodes with larger in-degrees.
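The grouping can be sketched in plain Python without DGL. This is a simplified model, not DGL's actual implementation: `degree_buckets` is a hypothetical helper, edge IDs are assumed to be positions in an `edges` list, and features are plain lists. Each bucket's mailbox has shape (B, L, D), where B is the number of nodes whose in-degree is L, which is why one printed batch can contain only in-degree-1 nodes.

```python
from collections import defaultdict


def degree_buckets(edges, feat):
    """Group destination nodes by in-degree and build one mailbox per bucket.

    edges: list of (src, dst) pairs; the list index serves as the edge ID.
    feat:  dict node -> feature vector (plays the role of srcdata['h']).
    Returns {in_degree: (dst_nodes, mailbox)}, mailbox shaped (B, in_degree, D).
    """
    in_edges = defaultdict(list)                  # dst -> [(edge_id, src), ...]
    for eid, (src, dst) in enumerate(edges):
        in_edges[dst].append((eid, src))

    by_degree = defaultdict(list)                 # in_degree -> [dst, ...]
    for dst in sorted(in_edges):
        by_degree[len(in_edges[dst])].append(dst)

    buckets = {}
    for deg, dsts in sorted(by_degree.items()):
        # All nodes here have exactly `deg` messages, so the mailboxes stack
        # into a rectangular block; messages are sorted by edge ID.
        mailbox = [[feat[src] for _, src in sorted(in_edges[dst])] for dst in dsts]
        buckets[deg] = (dsts, mailbox)
    return buckets


edges = [(0, 2), (1, 2), (0, 3), (4, 3), (1, 4), (2, 4), (3, 4), (2, 0)]
feat = {n: [float(n)] for n in range(5)}          # D = 1 for readability
for deg, (dsts, mailbox) in degree_buckets(edges, feat).items():
    print(deg, dsts, mailbox)
# 1 [0] [[[2.0]]]
# 2 [2, 3] [[[0.0], [1.0]], [[0.0], [4.0]]]
# 3 [4] [[[1.0], [2.0], [3.0]]]
```

Node 1 has no in-edges, so this sketch simply gives it no bucket; each remaining bucket corresponds to one reducer call.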

Thanks @mufeili for the quick answer. Sorry, I had not quite understood it at first; thanks for clearing it up.

I printed more batches, and indeed, those nodes have multiple in-neighbors. For example, one of the batches has m of size [46, 13, 3], meaning that each of those 46 nodes has 13 in-neighbors. An LSTM works on a sequence of inputs, so how does DGL determine the order of those 13 in-neighbors?

I would assume that this is based on edge indexing, i.e. edges with smaller IDs are treated as ‘preceding’ edges with larger IDs. The DGL docs say: “Within each bucket, the edges are ordered by the edge IDs for each node.” Is that correct?

Thanks again!

Yes, I think you are right.
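The confirmed behavior can be illustrated with a small stdlib-only sketch. `mailbox_for` and `last_state` are hypothetical helpers, edge IDs are assumed to be positions in an edge list, and the recurrence `h = 0.5 * h + x` stands in for the LSTM. Swapping the IDs of two edges into the same node reverses that node's input sequence and changes the final state.

```python
# Hypothetical sketch (not DGL code): a node's input sequence follows
# ascending edge IDs, so renumbering edges changes what an
# order-sensitive reducer such as an LSTM sees.


def mailbox_for(dst, edges, feat):
    """Messages arriving at `dst`, in edge-ID order (= list index here)."""
    return [feat[src] for src, d in edges if d == dst]


def last_state(messages, decay=0.5):
    """Toy recurrence h_t = decay * h_{t-1} + x_t standing in for an LSTM."""
    h = 0.0
    for x in messages:
        h = decay * h + x
    return h


feat = {0: 1.0, 1: 2.0}          # scalar features for brevity

edges_a = [(0, 2), (1, 2)]       # edge 0 comes from node 0, edge 1 from node 1
edges_b = [(1, 2), (0, 2)]       # same edges with their IDs swapped

for edges in (edges_a, edges_b):
    seq = mailbox_for(2, edges, feat)
    print(seq, last_state(seq))
# [1.0, 2.0] 2.5
# [2.0, 1.0] 2.0
```

If the neighbor order carries no meaning in your data, note that the GraphSAGE paper applies its LSTM aggregator to a random permutation of each node's neighbors rather than to a fixed order.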

@mufeili Thank you for your answers!
