DGL Dataloader num workers > 0 problems

Hi,

I am using the GAT model with the standard batched graph classification framework from the examples. I am trying to use multiple workers in the PyTorch DataLoader to speed up batch creation, but I run into the following problem:

dgl._ffi.base.DGLError: Cannot update column of scheme Scheme(shape=(256,), dtype=torch.float32) using feature of scheme Scheme(shape=(19,), dtype=torch.float32).

Does batching graphs in the collate function work when the DataLoader is configured to use multiple worker processes?

Yes, we support loading data with multiple processes. Could you please show your code so that we can find where the problem is?
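
For reference, multi-worker loading is expected to work as long as the collate function keeps everything on CPU. Here is a minimal sketch with a toy dataset (names and features are purely illustrative):

import dgl
import torch as th
from torch.utils.data import DataLoader

def collate(samples):
    # `samples` is a list of (graph, label) pairs; everything stays on CPU here.
    graphs, labels = map(list, zip(*samples))
    return dgl.batch(graphs), th.tensor(labels)

# Toy dataset of 10 small graphs with a dummy node feature.
dataset = []
for i in range(10):
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.ndata['h'] = th.randn(5, 4)
    dataset.append((g, i % 2))

loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    collate_fn=collate, num_workers=4)
for bg, labels in loader:
    pass  # move bg / labels to the GPU here, in the main process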

Hi zihao.
I have met the same problem when setting num_workers of the DataLoader > 0.

This is my collate function:

def collate(samples, device):
    # Collate for generating batched data.
    # `samples` is a list of (graph, label) pairs.
    graphs, labels = map(list, zip(*samples))
    for graph in graphs:
        for k in graph.node_attr_schemes().keys():
            graph.ndata[k] = graph.ndata[k].to(device)
        for k in graph.edge_attr_schemes().keys():
            graph.edata[k] = graph.edata[k].to(device)

    batched_graph = dgl.batch(graphs)
    return batched_graph, th.tensor(labels, device=device)

And the DataLoader:

data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=lambda samples: collate(samples, self.device))

It works fine when num_workers is 0, but when I increase it above 0 I get the following error:

RuntimeError: Traceback (most recent call last):
  File "/data/home/xxx/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/data/home/cqlu/DGL_mol/dataset.py", line 100, in <lambda>
    collate_fn=lambda samples: collate(samples, self.device))
  File "/data/home/cqlu/DGL_mol/dataset.py", line 16, in collate
    graph.ndata[k] = graph.ndata[k].to(device)
RuntimeError: CUDA error: initialization error

You should never move data from CPU to GPU in the collate function. The worker subprocesses may start before the main process has initialized CUDA in PyTorch, and moving data there can also cause GPU out-of-memory problems (the subprocesses keep producing samples that the trainer has not consumed yet, and all of them occupy GPU memory).
The right way is to move data to the corresponding device in the main process. In your case, change your collate function to:

def collate(samples):
    # Collate for generating batched data.
    # `samples` is a list of (graph, label) pairs; all tensors stay on CPU here.
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, th.tensor(labels)

and move data to your device in the main process at each iteration:

data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=collate)
for g, y in data_loader:
    for k in g.node_attr_schemes().keys():
        g.ndata[k] = g.ndata[k].to(device)
    for k in g.edge_attr_schemes().keys():
        g.edata[k] = g.edata[k].to(device)
    y = y.to(device)
    ...

A tutorial on the multi-process data loader will be published soon, please stay tuned!

Thank you for replying.
I had tried moving the data from CPU to GPU in the main process, but that can lead to low GPU utilization. I also tried using a CUDA stream to prefetch the data and speed things up, but it didn’t work.
Looking forward to your tutorial on the multi-process data loader.
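
For reference, the stream-based prefetch I mean was roughly along these lines (an illustrative sketch only, not my actual code; edge features and error handling are omitted, and `data_loader` / `device` are the names from the snippet above):

import torch as th

class CUDAPrefetcher:
    # Copies the next batch to the GPU on a side stream while the
    # current batch is being consumed by the trainer.
    def __init__(self, loader, device):
        self.loader = iter(loader)
        self.device = device
        self.stream = th.cuda.Stream()
        self._preload()

    def _preload(self):
        try:
            g, y = next(self.loader)
        except StopIteration:
            self.next_batch = None
            return
        with th.cuda.stream(self.stream):
            for k in g.node_attr_schemes().keys():
                g.ndata[k] = g.ndata[k].to(self.device, non_blocking=True)
            self.next_batch = (g, y.to(self.device, non_blocking=True))

    def __iter__(self):
        return self

    def __next__(self):
        if self.next_batch is None:
            raise StopIteration
        # Make sure the copy issued on the side stream has finished.
        th.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_batch
        self._preload()
        return batch

for bg, label in CUDAPrefetcher(data_loader, device):
    ...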

Hi! I’m afraid I am getting the same error as user hellboy5 when num_workers>0. Please note that this is a different error from the one mentioned by user lunar.

def batcher(batch):
    lg_batch = dgl.batch(batch)
    features = lg_batch.ndata['type'].reshape(-1,1).float()
    labels = lg_batch.ndata['response'].reshape(-1,1)
    return [lg_batch, features, labels]

Dataloader:

train_loader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=batcher, num_workers=4)

And the error message:

DGLError: Cannot update column of scheme Scheme(shape=(1,), dtype=torch.float32) using feature of scheme Scheme(shape=(8,), dtype=torch.float32).

When num_workers=0 I can train the GAT model successfully. Any insights on what could be missing? Thank you!

That kind of error message happens when a field’s feature shape is already fixed, and you then try to update a subset of the data with a different shape, for example:

>>> import torch as th
>>> import dgl
>>> g = dgl.DGLGraph()
>>> g.add_nodes(5)
>>> g.ndata['x'] = th.rand(5, 4)
>>> g.nodes[1].data['x']
tensor([[0.9324, 0.1661, 0.0955, 0.2504]])
>>> g.nodes[1].data['x'] = th.rand(1, 3)
DGLError: Cannot update column of scheme Scheme(shape=(4,), dtype=torch.float32) using feature of scheme Scheme(shape=(3,), dtype=torch.float32).

In your case, I can’t infer from this alone where the mismatch comes from; could you elaborate with a minimal reproduction of the error when you use multiple workers?
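
One thing you could check while preparing that (just a suggestion; I am assuming your train_dataset yields DGLGraph objects, as in your batcher above) is whether every graph in the dataset has the same feature shapes before batching:

from collections import defaultdict

# Collect every (shape, dtype) seen for each node feature across the dataset,
# so any graph with a mismatched scheme can be spotted before dgl.batch is called.
schemes = defaultdict(set)
for g in train_dataset:
    for name, scheme in g.node_attr_schemes().items():
        schemes[name].add((scheme.shape, scheme.dtype))

for name, seen in schemes.items():
    if len(seen) > 1:
        print("Field '%s' has inconsistent schemes: %s" % (name, seen))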

Hi, recently I tried your advice about moving data to the GPU in the main process. It works fine when I set num_workers=0.
However, when I set num_workers > 0, I get a RuntimeError:

/run/cqlu/cqlu/DGL_mol/flow.py in train(self, epochs, batch_size, dataset, early_stop, lr, decay_rate, train_size, device, resume)
    103             # train
    104             self.model.train()
--> 105             for idx, (bg, label) in enumerate(train_loader):
    106                 for k in bg.node_attr_schemes().keys():
    107                     bg.ndata[k] = bg.ndata[k].to(device)

~/cq/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    580                 self.reorder_dict[idx] = batch
    581                 continue
--> 582             return self._process_next_batch(batch)
    583
    584     next = __next__  # Python 2 compatibility

~/cq/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_next_batch(self, batch)
    606                 raise Exception("KeyError:" + batch.exc_msg)
    607             else:
--> 608                 raise batch.exc_type(batch.exc_msg)
    609         return batch
    610

RuntimeError: Traceback (most recent call last):
  File "/data/home/cqlu/cq/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/run/cqlu/cqlu/DGL_mol/dataset.py", line 125, in <lambda>
    collate_fn=lambda samples: collate(samples, self.device))
  File "/run/cqlu/cqlu/DGL_mol/dataset.py", line 14, in collate
    batched_graph = dgl.batch(graphs)
  File "/data/home/cqlu/cq/lib/python3.6/site-packages/dgl/batched_graph.py", line 354, in batch
    return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
  File "/data/home/cqlu/cq/lib/python3.6/site-packages/dgl/batched_graph.py", line 192, in __init__
    for key in node_attrs}
  File "/data/home/cqlu/cq/lib/python3.6/site-packages/dgl/batched_graph.py", line 192, in <dictcomp>
    for key in node_attrs}
  File "/data/home/cqlu/cq/lib/python3.6/site-packages/dgl/backend/pytorch/tensor.py", line 100, in cat
    return th.cat(seq, dim=dim)
RuntimeError: CUDA error: initialization error

Have you met a similar problem before?

That’s weird, could you provide me with a snippet that could reproduce your bug?

import dgl
import torch as th
from torch.utils.data import DataLoader


def collate(samples):
    # Collate for generating batched data.
    # `samples` is a list of (graph, label) pairs.
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, th.tensor(labels)


# Build a toy dataset of 10 small graphs with a node feature "a".
graphs = []
for i in range(10):
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.ndata["a"] = th.Tensor([[j] for j in range(5)])
    graphs.append(g)

labels = [i for i in range(10)]
dataset = [(g, y) for g, y in zip(graphs, labels)]
device = "cuda:0"

dataloader_1 = DataLoader(dataset,
                          batch_size=5,
                          shuffle=False,
                          collate_fn=collate,
                          num_workers=0)

# Move batched data to the GPU in the main process at each iteration.
for bg, label in dataloader_1:
    for k in bg.node_attr_schemes().keys():
        bg.ndata[k] = bg.ndata[k].to(device)
    for k in bg.edge_attr_schemes().keys():
        bg.edata[k] = bg.edata[k].to(device)
    label = label.to(device)
    print(bg.ndata["a"])
    print(label)
print("Dataloader success")

It works when num_workers = 0, but it fails when num_workers = 4.

I can’t reproduce your bug when setting num_workers=4 in my environment. Try using the latest versions of PyTorch (v1.1) and DGL (v0.3) and see if the problem still exists.
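
A quick way to check which versions you have installed:

import torch
import dgl
# Print the installed PyTorch and DGL versions.
print(torch.__version__, dgl.__version__)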

Thank you, upgrading dgl from v0.2 to v0.3 fixed this problem.