Broadcast readout output to each node in BatchedDGLGraph?



I’m trying to implement a function that performs, say, the mean_nodes readout on a BatchedDGLGraph and then broadcasts each graph's mean back to every node in that graph, as a way to give each node information about the rest of its graph. Is there an elegant way to do this without unbatching and rebatching or using a for-loop?




Unfortunately, you have to unbatch and use a for-loop. Otherwise, you may need to write a custom operator for PyTorch or MXNet to avoid the loop. The operator behind mean_nodes is here, for your reference.
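For reference, the unbatch-and-loop approach boils down to something like the following. This is a framework-agnostic NumPy sketch, not DGL code: it assumes the batched node features are one flat array with each graph's nodes stored contiguously (the way batching concatenates graphs), and `broadcast_mean_loop` and `num_nodes_per_graph` are hypothetical names.

```python
import numpy as np

def broadcast_mean_loop(feats, num_nodes_per_graph):
    """Loop-based broadcast of each graph's mean node feature.

    Assumes (hypothetically) that feats has shape (total_nodes, feat_dim)
    and that each graph's nodes occupy one contiguous block of rows.
    """
    out = np.empty_like(feats, dtype=float)
    start = 0
    for n in num_nodes_per_graph:
        end = start + n
        # Mean over this graph's nodes, broadcast onto each of its rows.
        out[start:end] = feats[start:end].mean(axis=0)
        start = end
    return out

feats = np.array([[1.0], [3.0], [10.0], [20.0], [30.0]])
out = broadcast_mean_loop(feats, [2, 3])
# Graph 0's two rows become 2.0; graph 1's three rows become 20.0.
```

The Python-level loop runs once per graph, which is exactly the cost the question hopes to avoid for large batches.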

Are there any examples of how to define a new readout function?
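As a starting point, a readout is essentially a segment reduction over each graph's slice of the flat node-feature array. The sketch below is a hypothetical element-wise max readout written in NumPy (not a DGL API). It assumes contiguous per-graph node storage and at least one node per graph, since `np.maximum.reduceat` does not treat empty segments as empty.

```python
import numpy as np

def segment_max_readout(feats, num_nodes_per_graph):
    """Hypothetical max readout: one row per graph holding the element-wise
    maximum of that graph's node features.

    Assumes feats has shape (total_nodes, feat_dim) with each graph's nodes
    stored contiguously, and that every graph has at least one node
    (np.maximum.reduceat does not treat empty segments as empty).
    """
    counts = np.asarray(num_nodes_per_graph)
    # Start offset of each graph's nodes in the flat array, e.g. [2, 3] -> [0, 2].
    offsets = np.concatenate(([0], np.cumsum(counts)[:-1]))
    # Reduce each slice offsets[i]:offsets[i+1] (the last one runs to the end).
    return np.maximum.reduceat(feats, offsets, axis=0)

feats = np.array([[1, 5], [3, 2], [10, 0], [20, -1], [30, 4]])
result = segment_max_readout(feats, [2, 3])  # rows [3, 5] and [30, 4]
```

Swapping `np.maximum` for `np.add` or `np.minimum` gives sum or min readouts with the same offsets.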

In case anyone from the future needs to do this, here is how:

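import dgl
import numpy as np
import torch

def dgl_broadcast_mean(g):
    # Per-graph mean of the node feature, one row per graph in the batch.
    mean = dgl.mean_nodes(g, 'my_node_feature')
    # Graph id of every node, e.g. batch_num_nodes [2, 3] -> [0, 0, 1, 1, 1].
    idx = np.repeat(np.arange(g.batch_size), g.batch_num_nodes)
    idx = torch.LongTensor(idx).to(mean.device)
    g.ndata['my_node_feature_graph_mean'] = mean[idx]  # each node gets its graph's mean
    return g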
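The trick in `dgl_broadcast_mean` is that indexing the per-graph means with each node's graph id broadcasts them back in one step, with no loop. The same idea can be checked without DGL in plain NumPy. This sketch assumes the features are one flat array with graphs stored contiguously and no empty graphs; `broadcast_mean_vectorized` is a hypothetical name, and `np.add.reduceat` stands in for `dgl.mean_nodes`.

```python
import numpy as np

def broadcast_mean_vectorized(feats, num_nodes_per_graph):
    """Broadcast each graph's mean node feature to all of its nodes, no loop.

    Assumes feats has shape (total_nodes, feat_dim) with each graph's nodes
    stored contiguously, and that every graph has at least one node.
    """
    counts = np.asarray(num_nodes_per_graph)
    # Start offset of each graph's nodes in the flat array, e.g. [2, 3] -> [0, 2].
    offsets = np.concatenate(([0], np.cumsum(counts)[:-1]))
    # Per-graph sums divided by node counts: one mean row per graph.
    means = np.add.reduceat(feats, offsets, axis=0) / counts[:, None]
    # Graph id of every node, e.g. [2, 3] -> [0, 0, 1, 1, 1].
    graph_ids = np.repeat(np.arange(len(counts)), counts)
    # Fancy indexing copies each graph's mean onto each of its nodes.
    return means[graph_ids]

feats = np.array([[1.0], [3.0], [10.0], [20.0], [30.0]])
out = broadcast_mean_vectorized(feats, [2, 3])
```

If you would rather stay in PyTorch and skip the NumPy round-trip, `torch.repeat_interleave(torch.arange(batch_size), counts)` should build the same graph-id tensor directly from a tensor of per-graph node counts.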