Dataloader on GPU, RuntimeError: Prefetcher thread timed out at 30 seconds

Hi, I got an error when using the DataLoader on GPU.

I'm running minibatch training on my own dataset and hit the following error:

---------------------------------------------------------------------------
Empty                                     Traceback (most recent call last)
File ~/miniconda3/lib/python3.10/site-packages/dgl/dataloading/dataloader.py:499, in _PrefetchingIter._next_threaded(self)
    498 try:
--> 499     batch, feats, stream_event, exception = self.queue.get(timeout=prefetcher_timeout)
    500 except Empty:

File ~/miniconda3/lib/python3.10/queue.py:179, in Queue.get(self, block, timeout)
    178 if remaining <= 0.0:
--> 179     raise Empty
    180 self.not_empty.wait(remaining)

Empty: 

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[6], line 3
      1 for epoch in range(num_epoch):
      2     start=time.time()
----> 3     for input_nodes, positive_graph, negative_graph, blocks in train_dataloader:
      4         inner_start=time.time()
      5         blocks = [b.to(torch.device(device)) for b in blocks]

File ~/miniconda3/lib/python3.10/site-packages/dgl/dataloading/dataloader.py:512, in _PrefetchingIter.__next__(self)
    510 def __next__(self):
    511     batch, feats, stream_event = \
--> 512         self._next_non_threaded() if not self.use_thread else self._next_threaded()
    513     batch = recursive_apply_pair(batch, feats, _assign_for)
    514     if stream_event is not None:

File ~/miniconda3/lib/python3.10/site-packages/dgl/dataloading/dataloader.py:501, in _PrefetchingIter._next_threaded(self)
    499     batch, feats, stream_event, exception = self.queue.get(timeout=prefetcher_timeout)
    500 except Empty:
--> 501     raise RuntimeError(
    502         f'Prefetcher thread timed out at {prefetcher_timeout} seconds.')
    503 if batch is None:
    504     self.thread.join()

RuntimeError: Prefetcher thread timed out at 30 seconds.

Some details of my environment and graph below:
PyTorch version: 1.12.1+cu116

Graph details…
Graph(num_nodes={'item': 9432, 'user': 7120017},
      num_edges={('item', 'clicked-by', 'user'): 255122481, ('user', 'click', 'item'): 255122481},
      metagraph=[('item', 'user', 'clicked-by'), ('user', 'item', 'click')])

negative_sampler = dgl.dataloading.negative_sampler.Uniform(4)
sampler = dgl.dataloading.NeighborSampler([2, 2])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler)
train_eid_dict = {
    etype: g.edges(etype=etype, form='eid')
    for etype in g.canonical_etypes}
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DataLoader.
    g,                   # The graph
    train_eid_dict,      # The edges to iterate over
    sampler,             # The neighbor sampler
    device='cuda',       # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=4194304,  # Batch size
    shuffle=True,        # Whether to shuffle the edges every epoch
    drop_last=False,     # Whether to drop the last incomplete batch
    num_workers=1        # Number of sampler processes
)

Does anyone know how to fix this? Thanks in advance!

cc @BarclayII for awareness

It might be that your batch size (4 million) is so large that subgraph sampling cannot complete within 30 seconds. You could either set the DGL_PREFETCHER_TIMEOUT environment variable to a larger value, decrease the batch size, or set num_workers=0 to enable intra-minibatch parallelism.
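For example, a rough sketch of the last two options applied to your DataLoader (the smaller batch size below is only an illustration, not a recommendation):

train_dataloader = dgl.dataloading.DataLoader(
    g, train_eid_dict, sampler,
    device='cuda',
    batch_size=262144,  # illustrative value, much smaller than 4194304; tune for your hardware
    shuffle=True,
    drop_last=False,
    num_workers=0,      # sample in the main process to enable intra-minibatch parallelism
)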

Thanks. A smaller batch size works, but it also slows down training. Can you tell me an easy way to set DGL_PREFETCHER_TIMEOUT to a larger value without recompiling the source code?

In the source code I saw:

prefetcher_timeout = int(os.environ.get('DGL_PREFETCHER_TIMEOUT', '30'))

Do I just override it in my model code and then import dgl as usual?

os.environ['DGL_PREFETCHER_TIMEOUT'] = str(300)

Your code is correct; just set it before importing DGL and you should be good.
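For reference, a minimal sketch of that, using the 300-second value from your snippet:

import os
os.environ['DGL_PREFETCHER_TIMEOUT'] = '300'  # set before dgl is imported
import dgl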

Thanks, @BarclayII.

I have some follow-up questions.

  1. If I use a small batch size, training takes longer (even on GPU). But if I use a large batch size (e.g. 500000 edges per minibatch) in the DataLoader, I get the following error:
Traceback (most recent call last):
  File "/home/hadoop/GNN_SGD_NeighborSampling.py", line 547, in <module>
    for input_nodes, positive_graph, negative_graph, blocks in train_dataloader:
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/dataloading/dataloader.py", line 512, in __next__
    self._next_non_threaded() if not self.use_thread else self._next_threaded()
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/dataloading/dataloader.py", line 507, in _next_threaded
    exception.reraise()
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/utils/exception.py", line 57, in reraise
    raise exception
RuntimeError: Caught RuntimeError in prefetcher.
Original Traceback (most recent call last):
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/dataloading/dataloader.py", line 380, in _prefetcher_entry
    batch, feats, stream_event = _prefetch(batch, dataloader, stream)
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/dataloading/dataloader.py", line 338, in _prefetch
    batch = recursive_apply(batch, _record_stream, current_stream)
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/utils/internal.py", line 1038, in recursive_apply
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/utils/internal.py", line 1038, in <listcomp>
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/utils/internal.py", line 1038, in recursive_apply
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/utils/internal.py", line 1038, in <listcomp>
    return [recursive_apply(v, fn, *args, **kwargs) for v in data]
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/utils/internal.py", line 1040, in recursive_apply
    return fn(data, *args, **kwargs)
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/dataloading/dataloader.py", line 307, in _record_stream
    x.record_stream(stream)
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/heterograph.py", line 5608, in record_stream
    col.record_stream(stream)
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/frame.py", line 508, in record_stream
    self.data.record_stream(stream)
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/frame.py", line 231, in data
    self.storage = F.copy_to(self.storage, self.device[0], **self.device[1])
  File "/home/hadoop/miniconda3/lib/python3.10/site-packages/dgl/backend/pytorch/tensor.py", line 117, in copy_to
    return input.cuda(**kwargs)
RuntimeError: CUDA out of memory. Tried to allocate 790.00 MiB (GPU 0; 14.56 GiB total capacity; 13.14 GiB already allocated; 696.44 MiB free; 13.25 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I assume that my graph/subgraph is too big, so the GPU doesn't have enough memory. Is there any way to train with a large minibatch size so that training doesn't take so long?

  2. In the DataLoader I have to specify the number of neighbors for sampling. If the minibatch is too big and the GPU runs out of memory, does that mean I also cannot set the number of neighbors too large, since the subgraph grows as it stores information from more neighbors?

Thanks in advance!

  1. I assume you are doing link prediction, since you mention you are training on edges. Our experience is that on a large graph you don't have to iterate over all the edges in a single “epoch”; link prediction models often get considerably good results even when trained on only a small subset of the edges.
  2. The size of the neighbor sampling result grows exponentially with the number of layers. You could try one of the subgraph samplers instead, e.g. ShaDowKHopSampler or SAINTSampler; see the sketch below.
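To make both suggestions concrete, here is a rough, untested sketch. The 10% edge fraction and the [5, 5] fanouts are arbitrary placeholders, and you should check how ShaDowKHopSampler combines with as_edge_prediction_sampler in your DGL version:

import torch

# (1) Train on a random subset (here 10%) of the edges of each edge type.
train_eid_dict = {}
for etype in g.canonical_etypes:
    eids = g.edges(etype=etype, form='eid')
    perm = torch.randperm(len(eids))
    train_eid_dict[etype] = eids[perm[:int(0.1 * len(eids))]]

# (2) Replace the two-layer NeighborSampler with a subgraph sampler.
# Note that the minibatch it yields is a single subgraph rather than a list of
# MFG blocks, so the training loop would need to change accordingly.
sampler = dgl.dataloading.ShaDowKHopSampler([5, 5])
sampler = dgl.dataloading.as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler)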
