Save graph_data for dgl.graph when batch feeding

Hi,

Thanks for developing DGL, which makes chemical undirected-graph problems much easier to code.
I want to implement MEGNet, proposed in https://arxiv.org/abs/1812.05055, with DGL graphs. One difference in MEGNet is the introduction of a state feature, which is a vector of two elements, ["average_weight", "bonds_per_atom"]. The following function should explain what I want to do. However, this causes a problem when batch feeding, since batching merges the DGL graphs into one big graph. (Suppose each batch contains 20 molecules: the shape of g.graph_data["s"] should be [20, 2], but it is still [2], because the number of nodes in a batched DGL graph equals the total number of atoms in the 20 molecules.)

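import torch as th
import torch.nn as nn

class StateGenerator(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, g, p_name="s"):
        """Attach the state vector [average_weight, bonds_per_atom] to a DGL graph."""
        num_nodes = g.number_of_nodes()
        average_atomic_weight = th.sum(g.ndata["node_type"]).item() / num_nodes
        bonds_per_atom = g.number_of_edges() / num_nodes
        # graph_data is an ad hoc Python attribute set here, not a DGL field,
        # so on a batched graph this yields a single [2] vector for the whole batch.
        g.graph_data = {}
        g.graph_data[p_name] = th.tensor([average_atomic_weight, bonds_per_atom], dtype=th.float, device="cuda")
        return g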

Could anyone please give me some suggestions on this?
Thanks.

We plan to support this soon.

As a workaround for now, you can change the collate function. For example, originally:

def collate(graphs):
    return dgl.batch(graphs)

Now

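def collate(graphs, p_name="s"):
    bg = dgl.batch(graphs)
    # Something like this: stack the per-graph state vectors (each of shape [2])
    # into one [batch_size, 2] tensor and attach it to the batched graph.
    bg.graph_data = {p_name: torch.stack([g.graph_data[p_name] for g in graphs])}
    return bg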
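
To make the shapes concrete, here is a minimal sketch of the same idea using plain PyTorch tensors in place of DGL graphs. The helper names (graph_state, collate_states, broadcast_to_nodes) and the toy molecules are made up for illustration and are not part of DGL or MEGNet. It assumes the state is computed per molecule before batching, then stacked so each row belongs to one graph.

```python
import torch

def graph_state(atomic_weights, num_edges):
    """Per-graph state [average_weight, bonds_per_atom] as a shape-[2] tensor."""
    num_nodes = atomic_weights.shape[0]
    avg_weight = atomic_weights.float().mean()
    bonds_per_atom = torch.tensor(num_edges / num_nodes)
    return torch.stack([avg_weight, bonds_per_atom])

def collate_states(states):
    """Stack shape-[2] states into a [batch_size, 2] tensor, one row per graph."""
    return torch.stack(states)

def broadcast_to_nodes(batched_states, nodes_per_graph):
    """Repeat each graph's state row once for every node in that graph."""
    counts = torch.tensor(nodes_per_graph)
    return torch.repeat_interleave(batched_states, counts, dim=0)

# Two toy "molecules": (atomic weight per atom, number of bonds).
# The values are illustrative only.
mols = [
    (torch.tensor([12.0, 1.0, 1.0, 1.0, 1.0]), 4),
    (torch.tensor([16.0, 1.0, 1.0]), 2),
]
states = [graph_state(w, e) for w, e in mols]
batched = collate_states(states)  # shape [2, 2]
per_node = broadcast_to_nodes(batched, [len(w) for w, _ in mols])  # shape [8, 2]
```

In a real collate function, the per-graph node counts would come from the batched graph itself rather than a hand-built list; DGL's batched graphs expose this as batch_num_nodes, though whether it is an attribute or a method depends on the DGL version, so check the one you have installed.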

Feel free to ask if you have any further questions.
