Hi,
I am trying to understand how the LSTM aggregator works in SAGEConv, the GraphSAGE layer (link to the DGL docs).
There, the message creation/aggregation looks like this:
graph.srcdata['h'] = feat_src
graph.update_all(fn.copy_src('h', 'm'), self._lstm_reducer)
h_neigh = graph.dstdata['neigh']
With self._lstm_reducer defined by:
def _lstm_reducer(self, nodes):
    m = nodes.mailbox['m']  # (B, L, D)
    batch_size = m.shape[0]
    h = (m.new_zeros((1, batch_size, self._in_src_feats)),
         m.new_zeros((1, batch_size, self._in_src_feats)))
    _, (rst, _) = self.lstm(m, h)
    return {'neigh': rst.squeeze(0)}
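To check my understanding of the shapes, I reproduced the reducer in plain PyTorch, outside of DGL, using the dimensions from my graph (26 nodes in the mailbox, feature dim 3 — these numbers are just from my example, not anything general):

```python
import torch
import torch.nn as nn

# Hypothetical dimensions matching the mailbox I observe:
# 26 entries in the batch, sequence length 1, feature dim 3.
B, L, D = 26, 1, 3

# SAGEConv's LSTM uses hidden size equal to the input feature size.
lstm = nn.LSTM(D, D, batch_first=True)

m = torch.randn(B, L, D)  # stand-in for nodes.mailbox['m']
h0 = (m.new_zeros(1, B, D), m.new_zeros(1, B, D))

# rst is the final hidden state, shape (num_layers, B, hidden) = (1, B, D)
_, (rst, _) = lstm(m, h0)
neigh = rst.squeeze(0)    # shape (B, D): one aggregated vector per entry
print(neigh.shape)        # torch.Size([26, 3])
```

So whatever the first mailbox dimension means semantically, the reducer clearly returns one D-dimensional vector per entry along that dimension.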
When I try this on my own heterogeneous graph and inspect m before it goes into the LSTM, it has shape torch.Size([26, 1, 3]), which, if I understand correctly, corresponds to [batch_size, number of nodes in the sequence, dimensionality of the node features]. Is that correct?
In that case, I assume rst represents the hidden state after a 'sequence' of 1 node has been fed in. (This would mean each node only considers 1 neighbour.) Is that correct? Or is it the other way around, i.e. the hidden state after a 'sequence' of 26 nodes, so that 26 neighbours are considered?
If it is a sequence of 26 nodes, how does DGL determine the order of those nodes?
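The reason I ask about ordering: unlike mean or pool aggregation, an LSTM is not permutation-invariant, so the order of the neighbour messages should change the result. A quick plain-PyTorch check (dimensions here are made up for illustration) confirms this:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(3, 3, batch_first=True)

# One destination node receiving a 'sequence' of 5 neighbour messages.
m = torch.randn(1, 5, 3)
m_rev = m.flip(1)  # the same messages, fed in reverse order

_, (h_fwd, _) = lstm(m)
_, (h_rev, _) = lstm(m_rev)

# The final hidden states generally differ: the LSTM is order-sensitive.
print(torch.allclose(h_fwd, h_rev))
```

So if the mailbox really holds a sequence of neighbours, the ordering DGL picks would directly affect the aggregated feature.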
Thanks a lot in advance!