Batching Heterogeneous graphs

Is there any support for batching heterogeneous graphs? If not, any tips on the best approach to implement it myself?

I will work on this, but it will take some more time. For the time being, I'd suggest explicitly constructing multiple homogeneous graphs and batching each type of graph separately.

Thanks! We thought of this hack, but it's not so straightforward when there are edges between different node types…
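Cross-type edges can still be handled with the usual ID-offset trick, as long as the offsets are tracked per node type: an edge of canonical type `(src_type, etype, dst_type)` has its source shifted by the running node count of `src_type` and its destination by the running count of `dst_type`. The sketch below shows just that bookkeeping in plain Python. It assumes each graph is a hypothetical dict with `num_nodes` (node type to count) and `edges` (canonical edge type to `(src_ids, dst_ids)`) instead of a DGL object, and `batch_hetero_edges` is an illustrative name, not a DGL function.

```python
from typing import Dict, List, Tuple

EType = Tuple[str, str, str]  # (src_type, edge_type, dst_type)
EdgeLists = Tuple[List[int], List[int]]


def batch_hetero_edges(graphs: List[dict]) -> Tuple[Dict[str, int], Dict[EType, EdgeLists]]:
    """Merge edge lists of several graphs, offsetting node IDs per node type.

    Each graph is a hypothetical dict:
      {"num_nodes": {ntype: count}, "edges": {(src_t, e, dst_t): (src_ids, dst_ids)}}
    All graphs are assumed to share the same node and edge types.
    """
    # running per-type node counts; after the loop they are the batched counts
    num_nodes = {t: 0 for t in graphs[0]["num_nodes"]}
    batched = {et: ([], []) for et in graphs[0]["edges"]}
    for g in graphs:
        for (src_t, e, dst_t), (src, dst) in g["edges"].items():
            out_src, out_dst = batched[(src_t, e, dst_t)]
            # the two endpoints may belong to different node types,
            # so each side is shifted by the offset of its own type
            out_src.extend(s + num_nodes[src_t] for s in src)
            out_dst.extend(d + num_nodes[dst_t] for d in dst)
        # move the per-type offsets past this graph's nodes
        for t, n in g["num_nodes"].items():
            num_nodes[t] += n
    return num_nodes, batched


# usage: a "user -> item" relation crossing node types, plus "user -> user"
g1 = {"num_nodes": {"user": 2, "item": 3},
      "edges": {("user", "clicks", "item"): ([0, 1], [2, 0]),
                ("user", "follows", "user"): ([0], [1])}}
g2 = {"num_nodes": {"user": 1, "item": 2},
      "edges": {("user", "clicks", "item"): ([0], [1]),
                ("user", "follows", "user"): ([0], [0])}}
num_nodes, edges = batch_hetero_edges([g1, g2])
# num_nodes == {"user": 3, "item": 5}
# edges[("user", "clicks", "item")] == ([0, 1, 2], [2, 0, 4])
# edges[("user", "follows", "user")] == ([0, 2], [1, 2])
```

The second graph's item 1 becomes item 4 because three items precede it, while its user 0 becomes user 2 because two users precede it; the two offsets are independent.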

@lederg I was looking into batching heterographs and came up with a workaround, shown in the code block below, that you may find useful. I did not test the speed, but I think it should not be too bad, and it is at least better than looping over individual graphs.

Note that the code below only handles nodes and their associated features: edge features are not carried over, and the order of edges may change in the batched graph.

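```python
from collections import defaultdict

import dgl
import torch


def graph_list_to_batch(graph_list):
    # query graph info (all graphs are assumed to share the node types,
    # edge types and node feature names of the first one)
    g = graph_list[0]
    ntypes = g.ntypes
    etypes = g.canonical_etypes
    attrs = {t: list(g.nodes[t].data.keys()) for t in ntypes}

    # graph connectivity: shift every node ID by the number of nodes of the
    # same type in the graphs that come before it
    current_num_nodes = {t: 0 for t in ntypes}
    connectivity = {t: ([], []) for t in etypes}
    for g in graph_list:
        for t in etypes:
            src, _, dest = t
            src_ids, dest_ids = connectivity[t]
            for i in range(g.number_of_nodes(src)):
                # use the canonical edge type so that edge names shared by
                # several relations are not ambiguous
                for j in g.successors(i, etype=t).tolist():
                    src_ids.append(i + current_num_nodes[src])
                    dest_ids.append(j + current_num_nodes[dest])
        for t in ntypes:
            current_num_nodes[t] += g.number_of_nodes(t)

    # create batched graph; num_nodes_dict keeps nodes without any edges,
    # otherwise the node count would be inferred from the largest node ID
    batch_g = dgl.heterograph(
        {
            t: (torch.tensor(s, dtype=torch.int64), torch.tensor(d, dtype=torch.int64))
            for t, (s, d) in connectivity.items()
        },
        num_nodes_dict=current_num_nodes,
    )

    # graph data (node only)
    slices = {n: [] for n in ntypes}
    data = {n: defaultdict(list) for n in ntypes}
    for g in graph_list:
        for t in ntypes:
            for a in attrs[t]:
                data[t][a].append(g.nodes[t].data[a])
            slices[t].append(g.number_of_nodes(t))
    # batch data: concatenate the per-graph features of each node type
    for t, dt in data.items():
        for a, d in dt.items():
            data[t][a] = torch.cat(d)
    # add batch data to batch graph
    for t in ntypes:
        batch_g.nodes[t].data.update(data[t])

    # attach graph list and data slices for later split
    batch_g.graph_list = graph_list
    batch_g.node_slices = slices

    return batch_g
```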
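The `node_slices` attribute is what makes the later split possible: node features of each type are concatenated graph by graph, so cutting them at the running sums of the slice sizes recovers each graph's rows. With real tensors, `torch.split(tensor, sizes)` accepts a list of section sizes and does this directly, e.g. `torch.split(batch_g.nodes[t].data[a], batch_g.node_slices[t])`. The pure-Python sketch below shows the same logic on plain lists; `split_by_slices` and `unbatch_node_data` are hypothetical helper names.

```python
from itertools import accumulate
from typing import Dict, List, Sequence


def split_by_slices(values: Sequence, sizes: List[int]) -> List[list]:
    """Cut a concatenated sequence back into consecutive chunks of the given sizes."""
    if sum(sizes) != len(values):
        raise ValueError("slice sizes do not add up to the number of values")
    ends = list(accumulate(sizes))
    starts = [0] + ends[:-1]
    return [list(values[s:e]) for s, e in zip(starts, ends)]


def unbatch_node_data(
    data: Dict[str, Dict[str, Sequence]], node_slices: Dict[str, List[int]]
) -> List[Dict[str, Dict[str, list]]]:
    """Return one {ntype: {attr: values}} dict per original graph."""
    num_graphs = len(next(iter(node_slices.values())))
    per_graph = [{t: {} for t in node_slices} for _ in range(num_graphs)]
    for t, attrs in data.items():
        for a, values in attrs.items():
            for i, chunk in enumerate(split_by_slices(values, node_slices[t])):
                per_graph[i][t][a] = chunk
    return per_graph


# usage: two graphs with 2 + 1 users and 3 + 2 items
slices = {"user": [2, 1], "item": [3, 2]}
data = {"user": {"age": [31, 45, 27]},
        "item": {"price": [1.0, 2.5, 4.0, 3.0, 0.5]}}
graphs = unbatch_node_data(data, slices)
# graphs[0] == {"user": {"age": [31, 45]}, "item": {"price": [1.0, 2.5, 4.0]}}
# graphs[1] == {"user": {"age": [27]}, "item": {"price": [3.0, 0.5]}}
```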

@lederg @mjwen Support for batching/unbatching DGLHeteroGraph has been added, and you can use it by installing DGL from source. See the examples. Let me know if you have any questions.

Thanks @mufeili! I've tried it and it works out of the box very nicely :fireworks::fireworks::fireworks: