Hi DGL developers,
Thanks for this amazing framework! I apologize if this question has been asked before.
I am trying to implement a Tree LSTM based on the example in your documentation.
Assume a simple balanced three-level binary tree, with the leaf nodes at level 3 and the root node at level 1. From your implementation, it seems that the apply_node_func function is called first on all leaf nodes, and the results are delivered through the message_func function to the parent nodes, where they are aggregated by the reduce_func function. I assume that each parent node then calls apply_node_func again, and this cycle repeats up to the root.
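To make the assumed order concrete, here is a minimal plain-Python simulation of it. This is not DGL code: the names topo_frontiers, propagate, EDGES and NUM_NODES are hypothetical, and the sketch assumes each function is invoked once per batch of nodes that become ready together (a "frontier", i.e. one tree level here) rather than once per node. It only records the order of calls; no tensors are computed.

```python
from collections import defaultdict

# Hypothetical example tree: a balanced three-level binary tree stored as
# child -> parent edges, so messages flow from the leaves up to the root.
# Node 0 is the root (level 1), nodes 1-2 are level 2, nodes 3-6 are leaves (level 3).
EDGES = [(3, 1), (4, 1), (5, 2), (6, 2), (1, 0), (2, 0)]
NUM_NODES = 7


def topo_frontiers(num_nodes, edges):
    """Group nodes into frontiers; a node is ready once all its children are done."""
    in_deg = [0] * num_nodes
    succ = defaultdict(list)
    for src, dst in edges:
        in_deg[dst] += 1
        succ[src].append(dst)
    frontier = [n for n in range(num_nodes) if in_deg[n] == 0]  # the leaves
    frontiers = []
    while frontier:
        frontiers.append(frontier)
        ready = []
        for n in frontier:
            for parent in succ[n]:
                in_deg[parent] -= 1
                if in_deg[parent] == 0:
                    ready.append(parent)
        frontier = sorted(ready)
    return frontiers


def propagate(num_nodes, edges, log):
    """Record the assumed call order: message, reduce, apply, once per frontier."""
    for nodes in topo_frontiers(num_nodes, edges):
        in_edges = [(s, d) for s, d in edges if d in nodes]
        if in_edges:  # leaves have no children, so they skip message/reduce
            log.append(("message", in_edges))
            log.append(("reduce", sorted({d for _, d in in_edges})))
        log.append(("apply", nodes))


if __name__ == "__main__":
    calls = []
    propagate(NUM_NODES, EDGES, calls)
    for name, batch in calls:
        print(name, batch)
```

Under this model the log reads apply (leaves), then message/reduce/apply for level 2, then message/reduce/apply for the root, so two apply steps never appear back to back.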
Is this assumption correct? I am asking because when I put print statements in both the reduce_func and apply_node_func functions, I sometimes see a single apply_node_func call followed by multiple reduce_func calls, but I never see multiple apply_node_func calls in a row.
Thank you for your answer!