Mol line graph in jtnn example

Hi,
I’m trying to implement Neural Message Passing for Quantum Chemistry (https://arxiv.org/abs/1704.01212) in DGL, following the message passing network in the JTNN example, but I’m confused by how mol_line_graph is updated at https://github.com/dmlc/dgl/blob/master/examples/pytorch/jtnn/jtnn/mpn.py#L170. Why is the line graph updated inside the for loop over depth, while update_all on mol_graph is called only once?

Thanks in advance!

The implementation follows the Junction Tree Variational Autoencoder paper: https://arxiv.org/pdf/1802.04364.pdf

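The for loop implements the loopy-BP-like message passing in Equation 1. These messages live on the directed edges of the molecular graph, or equivalently on the nodes of the line graph. Each message on an edge u→v is recomputed from the previous round's messages on the edges w→u with w ≠ v, so the update has to be repeated depth times.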
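To make the depth loop concrete, here is a minimal pure-Python sketch of the Equation 1 update on a toy triangle-shaped molecule, using scalar features and made-up weights instead of tensors. It does not use DGL; the helper names (`directed_edges`, `line_graph_preds`, `edge_messages`), the features, and the weights W1, W2, W3 are all hypothetical. Each directed edge is a node of the line graph, and the edge w→u feeds u→v only when w ≠ v (a non-backtracking line graph; as far as I know, DGL exposes this through the `backtracking` option when building a line graph). Because every new message depends on the previous round's messages, the update is repeated `depth` times.

```python
import math

# Hypothetical toy molecule: a triangle of atoms 0, 1, 2.
# Scalar node features x_u and bond features x_uv stand in for real feature vectors.
NODE_FEAT = {0: 1.0, 1: 0.5, 2: -0.5}
BONDS = [(0, 1), (1, 2), (2, 0)]
BOND_FEAT = {frozenset(b): 0.1 for b in BONDS}

# Assumed scalar weights standing in for the matrices W1, W2, W3 of Equation 1.
W1, W2, W3 = 0.8, 0.3, 0.5


def directed_edges(bonds):
    """Each undirected bond u-v becomes two directed edges u->v and v->u."""
    return [e for u, v in bonds for e in ((u, v), (v, u))]


def line_graph_preds(edges):
    """Non-backtracking line graph: edge (w->u) feeds edge (u->v) iff w != v."""
    return {
        (u, v): [(w, x) for (w, x) in edges if x == u and w != v]
        for (u, v) in edges
    }


def edge_messages(depth):
    """Run `depth` synchronous rounds of Equation-1-style edge updates."""
    edges = directed_edges(BONDS)
    preds = line_graph_preds(edges)
    msg = {e: 0.0 for e in edges}  # initial messages are zero
    for _ in range(depth):
        # The whole new dict is built from the old `msg` before rebinding,
        # so every edge sees only the previous round's messages.
        msg = {
            (u, v): math.tanh(
                W1 * NODE_FEAT[u]
                + W2 * BOND_FEAT[frozenset((u, v))]
                + W3 * sum(msg[p] for p in preds[(u, v)])
            )
            for (u, v) in edges
        }
    return msg
```

For example, `edge_messages(3)` returns one message per directed edge (six for the triangle). Each loop iteration plays the role of one round of message passing on the line graph, which is why that part sits inside the loop over depth.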

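The final update_all on mol_graph runs only once because it is just the readout: it gathers the final edge states into the nodes, as in Equation 2, and nothing is iterated there.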
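The readout in Equation 2 can be sketched the same way, again in pure Python rather than DGL. The final messages below are made-up values for illustration, and `readout`, U1, and U2 are hypothetical names and scalar weights. Each node sums the messages on its incoming edges a single time and combines that sum with its own feature; there is no recurrence, so one pass is enough.

```python
import math

# Hypothetical final edge messages after the depth loop, keyed by directed
# edge (v, u); the values are invented for illustration only.
FINAL_MSG = {(0, 1): 0.2, (1, 0): 0.4, (1, 2): -0.1, (2, 1): 0.3}
NODE_FEAT = {0: 1.0, 1: 0.5, 2: -0.5}

# Assumed scalar weights standing in for the matrices U1, U2 of Equation 2.
U1, U2 = 0.7, 0.9


def readout(node_feat, final_msg):
    """Gather incoming edge messages into each node once (Equation-2 style)."""
    incoming = {u: 0.0 for u in node_feat}
    for (v, u), m in final_msg.items():  # the message on edge v->u goes to node u
        incoming[u] += m
    return {u: math.tanh(U1 * node_feat[u] + U2 * incoming[u]) for u in node_feat}
```

Calling `readout(NODE_FEAT, FINAL_MSG)` returns one hidden state per atom. This corresponds to the single gather step from edges to nodes, which is why it appears once after the loop rather than inside it.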