Linear Projection on Aggregation of the Neighbourhood

Can anyone explain the lines below from the SAGEConv source code?

Let’s say our input and output have the same dimensions. In this case, lin_before_mp will be False, so the if not lin_before_mp branch is taken, which in turn applies a linear layer to the neighbourhood-aggregated tensor.

I don’t see the point of this. Am I missing anything?

# Determine whether to apply linear transformation before message passing A(XW)
lin_before_mp = self._in_src_feats > self._out_feats

# Message Passing
if self._aggre_type == 'mean':
    graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
    graph.update_all(msg_fn, fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
    if not lin_before_mp:
        h_neigh = self.fc_neigh(h_neigh)

This is just a performance optimization.

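Suppose A is the adjacency matrix, X is the feature tensor, and W is the weight matrix.

lin_before_mp determines the order of evaluation: A(XW) or (AX)W. Matrix multiplication is associative, so both orders give the same result, but message passing is cheaper on narrower features. When the input dimension is larger than the output dimension, fc_neigh is applied first so that aggregation runs on the smaller output features; otherwise aggregation runs first on the input features. When the two dimensions are equal, either order costs the same and the code aggregates first. The linear layer is applied exactly once in every case, because it holds the learned weight W for the neighbourhood term.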
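To make the equivalence concrete, here is a minimal pure-Python sketch. It is not DGL code: the matmul helper, the toy matrices, and the aggregation_width function are hypothetical names introduced for illustration, and the bias term is ignored. A stands in for a row-normalised (mean) aggregation matrix.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

# Toy mean-aggregation matrix for 3 nodes: row i averages node i's neighbours.
A = [[0.0, 0.5, 0.5],
     [1.0, 0.0, 0.0],
     [0.5, 0.5, 0.0]]

# Node features: 3 nodes, in_feats = 4.
X = [[1.0, 2.0, 3.0, 4.0],
     [0.0, 1.0, 0.0, 1.0],
     [2.0, 0.0, 2.0, 0.0]]

# Weight of the linear layer: in_feats = 4 -> out_feats = 2.
W = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0],
     [0.5, -1.0]]

project_first = matmul(A, matmul(X, W))    # A(XW): aggregate 2-wide features
aggregate_first = matmul(matmul(A, X), W)  # (AX)W: aggregate 4-wide features

# Both orders agree (up to floating-point rounding in general).
same = all(
    abs(p - q) < 1e-9
    for row_p, row_q in zip(project_first, aggregate_first)
    for p, q in zip(row_p, row_q)
)

def aggregation_width(in_feats, out_feats):
    """Feature width carried through message passing under the rule in the snippet."""
    lin_before_mp = in_feats > out_feats
    return out_feats if lin_before_mp else in_feats
```

With this rule the aggregation always runs on min(in_feats, out_feats) columns, so for equal dimensions the choice of order makes no difference to the cost.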
