Different reducer function for different node types

If I’m trying to do something like GraphSAGE, except that I want different nodes to use different aggregation functions, how would I go about it?

For example, suppose I have a graph of the form

4 -> 0 -> 1 <- 2 <- 3

where nodes 0, 1 and 2 are of type A (for which I want an LSTM reducer) and nodes 3 and 4 are of type B (for which I want a mean aggregator).

I’m not sure how this would be handled within the DGL framework; any help would be appreciated. Thanks!


You can call pull separately on each group of nodes, passing a different reduce function each time:

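import dgl.function as fn

# Assumes node features live in g.ndata['h']. lstm_reduce is a user-defined reduce
# function you write yourself (hypothetical name) that runs an LSTM over nodes.mailbox['m']
# and returns {'neigh': ...}, so both groups write to the same output field.
g.pull([0, 1, 2], message_func=fn.copy_u('h', 'm'), reduce_func=lstm_reduce)
g.pull([3, 4], message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'neigh'))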

See the DGL API documentation for DGLGraph.pull for the full signature and options.
