What's the calling order of reduce_func and apply_node_func in a TreeLSTM?

Hi DGL developers,

Thanks for this amazing framework! I apologize if this question has been asked before.

I am trying to implement a Tree-LSTM based on the example in your documentation.
Assume a simple, balanced three-level binary tree, with the leaf nodes at level 3 and the root at level 1. From your implementation, it seems that apply_node_func is called first on all leaf nodes, and the results are delivered through message_func to the parent nodes, where they are aggregated by reduce_func. I assume that each parent node then calls apply_node_func again, restarting the cycle.

Is this assumption correct? I ask because when I put print statements in both reduce_func and apply_node_func, I sometimes see a single apply_node_func call followed by multiple reduce_func calls, but I never see multiple apply_node_func calls in a row.
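Print statements make it hard to see how the calls group together. A minimal sketch, assuming plain Python callbacks: the hypothetical `traced` decorator and `summarize` helper below (not part of DGL) record each callback's name and collapse consecutive repeats, so a run such as one apply followed by two reduces is easy to spot. The `reduce_func` and `apply_node_func` defined here are empty stand-ins, not the Tree-LSTM cell's real UDFs.

```python
import functools
from itertools import groupby

# Hypothetical global log; clear it before each propagation run.
CALL_LOG = []


def traced(fn):
    """Wrap a callback so that every call appends its name to CALL_LOG."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        CALL_LOG.append(fn.__name__)
        return fn(*args, **kwargs)
    return wrapper


def summarize(log):
    """Collapse consecutive repeats, e.g. ['a', 'r', 'r'] -> [('a', 1), ('r', 2)]."""
    return [(name, len(list(group))) for name, group in groupby(log)]


# Stand-in UDFs for illustration only; a real model would decorate its own
# reduce_func / apply_node_func, which receive a batch of nodes.
@traced
def reduce_func(nodes):
    return {}


@traced
def apply_node_func(nodes):
    return {}
```

To use it, decorate the model's own reduce_func and apply_node_func with `@traced`, call `CALL_LOG.clear()` before running the propagation, and print `summarize(CALL_LOG)` afterwards.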

Thank you for your answer!

Hi, the order is always message and reduce first, then apply_node_func.

In prop_nodes, leaf nodes have no incoming edges, so only apply_node_func is executed on them.
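The ordering described above can be illustrated with a pure-Python simulation. This sketch only mimics the schedule and does not call DGL; `topo_frontiers` and `simulate_prop_nodes` are hypothetical names. Edges point from child to parent, nodes are processed frontier by frontier in topological order, and within each frontier message and reduce run before a single apply_node_func call over the whole frontier. Leaves have no incoming edges, so their frontier gets only an apply. The sketch also assumes that receiving nodes are grouped by in-degree (degree bucketing), with one reduce_func call per group; under that assumption, a frontier that mixes in-degrees produces several reduce calls between two apply calls, which would fit the print pattern described in the question.

```python
from collections import defaultdict


def topo_frontiers(num_nodes, edges):
    """Yield frontiers level by level (Kahn's algorithm): a node becomes
    ready once all of its predecessors have been processed."""
    in_deg = [0] * num_nodes
    succ = defaultdict(list)
    for u, v in edges:  # edge u -> v, i.e. child -> parent in a tree
        in_deg[v] += 1
        succ[u].append(v)
    frontier = [n for n in range(num_nodes) if in_deg[n] == 0]
    while frontier:
        yield frontier
        nxt = []
        for u in frontier:
            for v in succ[u]:
                in_deg[v] -= 1
                if in_deg[v] == 0:
                    nxt.append(v)
        frontier = sorted(nxt)


def simulate_prop_nodes(num_nodes, edges):
    """Return the sequence of simulated callback calls as (name, nodes) tuples."""
    preds = defaultdict(list)
    for u, v in edges:
        preds[v].append(u)
    log = []
    for frontier in topo_frontiers(num_nodes, edges):
        # Only nodes with incoming edges receive messages; leaves are skipped.
        receivers = [n for n in frontier if preds[n]]
        if receivers:
            log.append(("message", tuple(receivers)))
            # Assumed degree bucketing: one reduce call per distinct in-degree.
            buckets = defaultdict(list)
            for n in receivers:
                buckets[len(preds[n])].append(n)
            for deg in sorted(buckets):
                log.append(("reduce", tuple(buckets[deg])))
        # apply_node_func runs once over the whole frontier, after all reduces.
        log.append(("apply", tuple(frontier)))
    return log
```

For the balanced tree in the question (nodes 3 to 6 as leaves, 1 and 2 at level 2, 0 as the root), `simulate_prop_nodes(7, [(3, 1), (4, 1), (5, 2), (6, 2), (1, 0), (2, 0)])` gives one apply on the leaves, then message, reduce, apply for level 2, then the same for the root. In a tree where node 2 has only one child, such as `simulate_prop_nodes(6, [(3, 1), (4, 1), (5, 2), (1, 0), (2, 0)])`, the level-2 frontier produces two reduce calls (one per in-degree) before its single apply.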