I’m trying to implement a function that performs, say, the mean_nodes readout function on a BatchedDGLGraph, then broadcasts the mean of each graph back to every node in its respective graph. The idea is to give each node information about the rest of its graph. Is there a way to do this elegantly without having to unbatch / rebatch or use a for-loop?
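
In symbols (the notation is introduced here for illustration): if graph $g$ in the batch has node set $V_g$ and node $v$ has feature $h_v$, then `mean_nodes` gives the per-graph readout $\bar{h}_g$, and the goal is to attach that readout to every node of $g$:

$$
\bar{h}_g = \frac{1}{|V_g|} \sum_{u \in V_g} h_u, \qquad h'_v = \bar{h}_g \quad \text{for all } v \in V_g.
$$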

Unfortunately, you have to unbatch and use a for-loop. Otherwise, to avoid the loop, you may need to write a custom operator for PyTorch or MXNet. The operator behind mean_nodes is here, for your reference.
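
A minimal sketch of the loop-based approach, assuming PyTorch and assuming the node features of the batch are a single `(total_nodes, D)` tensor laid out graph by graph, which is how batching concatenates them. Rather than unbatching into graph objects, it splits that tensor by the per-graph node counts, takes each chunk's mean, and repeats it for every node of that graph. `mean_broadcast_loop` and its `num_nodes` argument (a plain list of per-graph node counts) are hypothetical names for illustration, not part of the DGL API.

```python
from typing import List

import torch


def mean_broadcast_loop(feat: torch.Tensor, num_nodes: List[int]) -> torch.Tensor:
    """Per-graph mean of node features, broadcast back to each node of that graph.

    feat: (total_nodes, D) node features of all graphs, concatenated in batch order.
    num_nodes: hypothetical plain list holding the node count of each graph.
    """
    pieces = []
    # Split the batched feature tensor into one chunk per graph
    # (the tensor-level counterpart of unbatching).
    for chunk in torch.split(feat, num_nodes, dim=0):
        graph_mean = chunk.mean(dim=0, keepdim=True)  # (1, D) readout for this graph
        # Repeat the readout once for every node in this graph.
        pieces.append(graph_mean.expand(chunk.shape[0], -1))
    # Re-concatenate so rows line up with the original batched node order.
    return torch.cat(pieces, dim=0)


if __name__ == "__main__":
    feat = torch.tensor([[1.0], [3.0], [10.0], [20.0], [30.0]])
    print(mean_broadcast_loop(feat, [2, 3]))
```

With two graphs of 2 and 3 nodes and 1-dimensional features `[1, 3]` and `[10, 20, 30]`, every node of the first graph receives `2.0` and every node of the second receives `20.0`. This is still a Python loop over the graphs in the batch, so the per-graph overhead grows with batch size; removing that loop is what a custom operator would buy.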