About node features in batch training

Hello everyone! I have a question about node features in batch training.
If all node features cannot stay on the GPU (OOM), should the batch node features on the GPU be freed at every batch iteration? If so, where does DGL do this work?

If all the node features cannot fit into GPU memory, you could keep them on the CPU. Once you obtain a batch from the dataloader, just slice the features according to input_nodes and move them to the GPU device. This way, each batch's features will be freed after its mini-batch iteration. Here's a reference: dgl/train_dist.py at ea706cae58ef66d772e2ed3f16c67de5370397d6 · dmlc/dgl · GitHub. In that example, both the graph and the node features live on the CPU, but you could also put the graph on the GPU while keeping the features on the CPU, which enables sampling directly on the GPU.
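Roughly, the pattern looks like this. This is only a minimal sketch: feats (the full node-feature tensor kept on the CPU), graph, train_nids and device are placeholders for your own objects, not names taken from that example.

import dgl
import torch

sampler = dgl.dataloading.NeighborSampler([4, 4])
dataloader = dgl.dataloading.DataLoader(
    graph,              # graph kept on CPU (or on GPU if you want GPU sampling)
    train_nids,
    sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
)

for input_nodes, output_nodes, mfgs in dataloader:
    # Slice only this batch's rows from the CPU feature tensor and copy them to the GPU.
    batch_feats = feats[input_nodes].to(device)
    mfgs = [mfg.to(device) for mfg in mfgs]
    # ... forward / backward / optimizer step ...
    # When batch_feats and mfgs are rebound on the next iteration, the previous
    # batch's GPU tensors lose their last reference and are released.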

Thanks for your reply! It seems that if I transfer the batch node features manually, they can be freed automatically after each mini-batch iteration? How does DGL achieve this?
And what if I use code like mfgs[0].srcdata['feat'] to get the batch node features? I mean:

import dgl
import torch.nn.functional as F

sampler = dgl.dataloading.NeighborSampler([4, 4])

train_dataloader = dgl.dataloading.DataLoader(
    graph,             
    train_nids,         
    sampler,            # The neighbor sampler
    device=device,      # Put the sampled MFGs on CPU or GPU
    batch_size=1024,   
    shuffle=True,       
    drop_last=False,    
    num_workers=0,      
)
def train():
    for epoch in range(EPOCHS):
        model.train()
        for step, (input_nodes, output_nodes, mfgs) in enumerate(train_dataloader):
            # batch node feats are transferred the first time they are used
            inputs = mfgs[0].srcdata['feat']
            # batch node labels
            labels = mfgs[-1].dstdata['label']
            predictions = model(mfgs, inputs)
            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

train()

Will batch node features be freed after each mini-batch iteration?

Yes, the batch node features (inputs) will be freed at some point, once there are no remaining references to mfgs. This is handled by Python's garbage collection (in practice, CPython's reference counting).
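For example, here is a toy illustration (not DGL-specific; the tensor sizes are arbitrary):

import torch

x = torch.randn(1024, 100, device='cuda')   # stands in for one batch's features
print(torch.cuda.memory_allocated())        # > 0 while x is referenced

x = torch.randn(1024, 100, device='cuda')   # rebinding, like a loop variable on the next iteration
# The old tensor's reference count drops to zero, so CPython frees it immediately;
# the gc module is only needed for reference cycles.
print(torch.cuda.memory_allocated())        # back to roughly the size of a single batch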

Thanks for your reply! I know little about Python's garbage collection; does it have something to do with GPU memory management?

GPU memory is handled by PyTorch and the underlying CUDA library.
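Concretely, PyTorch keeps a caching allocator on top of the CUDA driver: when a tensor is freed on the Python side, its memory goes back to that cache for reuse rather than straight back to the driver. A quick way to observe this (illustrative only):

import torch

x = torch.randn(1 << 20, device='cuda')
print(torch.cuda.memory_allocated())    # bytes held by live tensors
print(torch.cuda.memory_reserved())     # bytes PyTorch's caching allocator holds from CUDA

del x
print(torch.cuda.memory_allocated())    # drops once Python frees the tensor
print(torch.cuda.memory_reserved())     # usually unchanged: the cache is kept for reuse
torch.cuda.empty_cache()                # optionally return unused cached blocks to the driver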

Thank you so much for your precious time and reply!
