Create subgraphs

I just learned that graph convolution over multiple subgraphs can be carried out on a batched graph (dgl.batch). Thank you very much.
Is there a fast way to generate many subgraphs of one heterogeneous graph at once? Right now I create the subgraphs one by one in a loop, and that takes a lot of time.

Could you provide more details? For example, some code/pseudo-code would be appreciated.

import dgl

g = dgl.heterograph(graph_data)  # heterograph
all_sub = []  # store all subgraphs
# len(sub_ids) == 300
# Each entry holds the node IDs of one subgraph, one list per node type
sub_ids = [[[type1], [type2], [type3]], ...]
for sub_id in sub_ids:
    # This call takes a lot of time
    sub_g = dgl.node_subgraph(g, {'xx': sub_id[0], 'yy': sub_id[1], 'zz': sub_id[2]})
    all_sub.append(sub_g)
# Batch the subgraphs
bg = dgl.batch(all_sub)

I don’t know of any clever way to speed this up, so I’m stuck.

Depending on your scenario, there are several ways to speed this up.

  • If the sub_ids are the same across all iterations, you could pre-generate all the subgraphs and batch them before training. The cost will be paid once and is usually small compared with the training time.
  • If the sub_ids are the same across all iterations but at each iteration you wish to batch different subgraphs as mini-batch inputs to your model, you could create a Dataset with each data sample being a subgraph extracted from sub_ids. You can then use the GraphDataLoader, which speeds up data batching via multiprocessing.
  • If the sub_ids are different at each training iteration, then it becomes a graph sampling procedure. You can write a custom sampler that calls dgl.node_subgraph in its sample method. Below is some demo code; it may not run as-is and is only meant to show the concept:
import dgl

class MySampler(dgl.dataloading.Sampler):
    def __init__(self):
        super().__init__()

    def sample(self, g, sub_id):
        # sub_id is a dict {node type: node IDs} for one mini-batch of seeds
        return dgl.node_subgraph(g, sub_id)

g = dgl.heterograph(graph_data)  # heterograph
all_sub_ids = {'type1': ..., 'type2': ...}  # one entry per node type
sampler = MySampler()
dataloader = dgl.dataloading.DataLoader(
    g,
    all_sub_ids,
    sampler,
    batch_size=xxx,  # num nodes of each subgraph
    # you can play with other arguments, such as num_workers, to accelerate the process
)
subgs = []
for subg in dataloader:
    subgs.append(subg)
    if len(subgs) == 300:  # the number of subgs you'd like to batch
        batched_subg = dgl.batch(subgs)
        train_model(batched_subg)
        subgs = []

Thank you very much.
Mine is the third case.