I would expect DataLoader to load batches concurrently with the main process, refilling the buffer as soon as a batch is consumed from it. However, when I track GPU utilization and the order of loading vs. execution, I see different behavior:
1. The whole buffer is loaded (expected).
2. The whole buffer is consumed by executing all of its batches until it is empty (not expected).
3. The whole buffer is loaded again, without any parallel execution (not expected).
4. Go to step 2.
This results in dips in GPU utilization during step 3.
num_workers >= 1
pin_memory = True/False (doesn’t influence the described behavior)
Did anyone experience the same? What could be the issue?
Your description would mean that, e.g., with 10 workers the GPU would starve after 10 batches, since each worker would then have to load its next batch?
I haven’t seen this behavior yet and only see a latency in the first epoch of the DataLoader, where each worker is loading its batch. Afterwards, the data loading time should be hidden in the background, if your actual training workload is large enough.
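A rough way to make "large enough" concrete is a back-of-the-envelope model. This is a simplified sketch, not how DataLoader is implemented: it assumes each worker needs t_load seconds per batch, the training step takes t_step seconds, workers run perfectly in parallel, and it ignores the first-batch latency. The helper name is hypothetical.

```python
def steady_state_utilization(t_load: float, t_step: float, num_workers: int) -> float:
    """Idealized fraction of time the training step is busy (hypothetical helper).

    Assumptions: every batch takes t_load seconds to load in one worker, the
    training step takes t_step seconds, workers run perfectly in parallel,
    and start-up latency (the first batches of an epoch) is ignored.
    """
    if num_workers == 0:
        # Loading happens in the main process, strictly before each step.
        return t_step / (t_step + t_load)
    # The workers together deliver one batch every t_load / num_workers seconds;
    # the loop cannot run faster than the slower of the two stages.
    interval = max(t_step, t_load / num_workers)
    return t_step / interval
```

With made-up numbers t_load = 0.5 s and t_step = 0.1 s, ten workers keep up (utilization 1.0), while two workers deliver a batch only every 0.25 s (utilization 0.4). In this model, 10 workers do not starve the GPU after 10 batches, because each worker starts on its next batch as soon as it hands one off.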
Yes, that’s more or less what happens.
After some isolated tests that only load batches from the DataLoader without any GPU processing, I think the problem could stem from the size of the epoch and how the DataLoader iterates over epochs. The current dataset only has 50 samples (long videos), from which smaller samples are taken, so each epoch consists of 50 samples. Is it possible that the DataLoader only prefetches until the end of the epoch, and after the last batch is prefetched, it waits until all batches have been processed on the GPU before it somehow “restarts”? This would explain the behavior I’ve seen lately: near 100% GPU utilization with regular intermediate dips to 0%.
(I’m using a map-style dataset)
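An isolated test like the one described can be sketched by timing how long the loop blocks on each next() call, epoch by epoch. This is a minimal sketch under stated assumptions: SlowDataset is a hypothetical stand-in for the video dataset (50 items, time.sleep in place of decoding), and time.sleep also stands in for the GPU step. If the epoch boundary is the culprit, the first wait of every epoch should be much longer than the rest.

```python
import statistics
import time

import torch
from torch.utils.data import DataLoader, Dataset


class SlowDataset(Dataset):
    """Hypothetical stand-in for the video dataset: few items, each slow to load."""

    def __init__(self, num_items: int = 50, load_seconds: float = 0.02):
        self.num_items = num_items
        self.load_seconds = load_seconds

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        time.sleep(self.load_seconds)  # simulate decoding a clip
        return torch.full((3,), float(idx))


def wait_times_per_epoch(loader, num_epochs: int, step_seconds: float = 0.0):
    """Return, per epoch, the seconds the loop blocked on each next() call.

    step_seconds simulates the work done per batch (time.sleep as a stand-in).
    """
    all_waits = []
    for _ in range(num_epochs):
        waits = []
        it = iter(loader)  # a new iterator per epoch, as in a plain `for` loop
        while True:
            start = time.perf_counter()
            try:
                next(it)
            except StopIteration:
                break
            waits.append(time.perf_counter() - start)
            time.sleep(step_seconds)  # pretend to train on the batch
        all_waits.append(waits)
    return all_waits


if __name__ == "__main__":
    loader = DataLoader(SlowDataset(), batch_size=5, num_workers=2)
    for epoch, waits in enumerate(wait_times_per_epoch(loader, 3, step_seconds=0.1)):
        print(f"epoch {epoch}: first wait {waits[0]:.3f}s, "
              f"median of the rest {statistics.median(waits[1:]):.3f}s")
```

The `__main__` guard keeps the example safe with the spawn start method, where worker processes re-import the main module.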
Probably that’s the issue right here (https://pytorch.org/docs/stable/data.html):
Multi-process data loading
Setting the argument num_workers as a positive integer will turn on multi-process data loading with the specified number of loader worker processes.
In this mode, each time an iterator of a DataLoader is created (e.g., when you call enumerate(dataloader)), num_workers worker processes are created. At this point, the dataset, collate_fn, and worker_init_fn are passed to each worker, where they are used to initialize, and fetch data. This means that dataset access together with its internal IO, transforms (including collate_fn) runs in the worker process.
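The quoted behavior can be observed directly by letting each sample report the PID of the process that loaded it. This is a small sketch; PidDataset and loader_pids_per_epoch are hypothetical names used only for illustration.

```python
import os

from torch.utils.data import DataLoader, Dataset


class PidDataset(Dataset):
    """Hypothetical dataset whose items report which process loaded them."""

    def __init__(self, num_items: int = 8):
        self.num_items = num_items

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        return os.getpid()


def loader_pids_per_epoch(loader, num_epochs: int):
    """Collect the set of process IDs that produced samples in each epoch."""
    pids = []
    for _ in range(num_epochs):
        epoch_pids = set()
        for batch in loader:  # each `for` calls iter(loader) again
            epoch_pids.update(int(p) for p in batch)
        pids.append(epoch_pids)
    return pids


if __name__ == "__main__":
    loader = DataLoader(PidDataset(), batch_size=2, num_workers=2)
    print(loader_pids_per_epoch(loader, 3))
```

With num_workers=2 and the default persistent_workers=False, each epoch is expected to report two new PIDs, because every new iterator starts new worker processes. Newer PyTorch releases also accept persistent_workers=True, which keeps the same workers across epochs. As far as I can tell, this saves the re-spawn cost but not the wait for the first batches of the next epoch, and changes made to the main-process Dataset between epochs are then not seen by the workers.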
Shouldn’t the implementation of torch.utils.data.DataLoader take care of seamless prefetching over epochs?
In many use cases, training a complete epoch takes much more time than reinitializing the worker processes, so the overhead at the start of each epoch can be ignored.
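Whether the epoch start "can be ignored" is a simple ratio. The sketch below assumes that loading is fully hidden inside an epoch and that each epoch begins with a fixed idle period (worker start-up plus loading the first batches); the helper name and the numbers in the note are hypothetical.

```python
def epoch_boundary_overhead(stall_seconds: float, batches_per_epoch: int,
                            step_seconds: float) -> float:
    """Fraction of wall time lost to the per-epoch stall (hypothetical helper).

    Assumes loading is fully hidden inside an epoch and that every epoch
    starts with stall_seconds of idle GPU time before the first batch arrives.
    """
    busy_seconds = batches_per_epoch * step_seconds
    return stall_seconds / (stall_seconds + busy_seconds)
```

With an illustrative 2 s stall and 0.1 s steps, a 1000-batch epoch loses about 2% of the time, while a 5-batch epoch loses about 80%. This is why the overhead is negligible for large epochs but dominates for very small ones.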
It’s probably an edge case, but you could also manipulate some arguments of the DataLoader itself or of the underlying Dataset between epochs.
However, if this behavior is the bottleneck in your code, you could, e.g., increase the length of your Dataset artificially (__len__ could return the length of multiple epochs, and __getitem__ could use the % operator to map the index back to a real sample), or warm up the next iterator via loader_iter2 = iter(train_loader) during one of the last iterations of the current epoch.
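Both workarounds can be sketched in a few lines. RepeatedDataset and epochs_with_warmup are hypothetical names; the warm-up helper assumes persistent_workers=False (the default), where every iter(loader) starts fresh workers that begin prefetching right away.

```python
from torch.utils.data import DataLoader, Dataset


class RepeatedDataset(Dataset):
    """Hypothetical wrapper: presents `repeats` passes over `base` as one epoch."""

    def __init__(self, base, repeats: int):
        self.base = base
        self.repeats = repeats

    def __len__(self):
        return len(self.base) * self.repeats

    def __getitem__(self, idx):
        # Map the enlarged index range back onto the real samples.
        return self.base[idx % len(self.base)]


def epochs_with_warmup(loader, num_epochs: int, warmup_batches: int = 2):
    """Yield (epoch, batch), creating the next epoch's iterator early.

    Assumes persistent_workers=False. With persistent_workers=True,
    iter(loader) resets the single shared iterator, which would interfere
    with the epoch in progress, so do not combine the two.
    """
    num_batches = len(loader)
    next_it = iter(loader)
    for epoch in range(num_epochs):
        it = next_it if next_it is not None else iter(loader)
        next_it = None
        for i, batch in enumerate(it):
            is_last_epoch = epoch + 1 == num_epochs
            if next_it is None and not is_last_epoch and i >= num_batches - warmup_batches:
                next_it = iter(loader)  # workers for the next epoch start loading now
            yield epoch, batch
```

With shuffle=True, the sampler permutes all len(base) * repeats indices at once, so each sample still appears exactly `repeats` times per enlarged epoch, but the passes are interleaved rather than being separate permutations. The warm-up variant briefly runs two sets of workers, which costs extra memory.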
I could replicate the issue with two (artificial) extreme cases:
1. an epoch of 10000 samples
2. an epoch of 1 sample
Case 1 reaches a 4 times higher bitrate when loading my data, which confirms the suspected issue.
That’s exactly what I came up with as well (first case).
I agree with your statement that this is probably an edge case. However, I would expect torch.utils.data.DataLoader to behave like an infinite iterator. I therefore think seamless batch loading across epochs should be implemented natively: whenever you train for more than one epoch, there is an inherent, unnecessary downtime (which users probably don’t expect), and with rather small epochs this downtime can lead to significantly lower GPU utilization, wasting precious processing time. Do you think this could be an easy fix in torch.utils.data.DataLoader?
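One way to get infinite-iterator behavior without changing DataLoader itself is to feed it a sampler that never runs out of indices, so its iterator never reaches an epoch boundary and the workers keep prefetching. This is a sketch, not a built-in feature: InfiniteRandomSampler and take_epochs are hypothetical names, and "epochs" become a fixed number of steps counted by the training loop.

```python
import itertools
from typing import Iterator, Optional

import torch
from torch.utils.data import DataLoader, Sampler


class InfiniteRandomSampler(Sampler[int]):
    """Hypothetical sampler that yields a fresh random permutation forever."""

    def __init__(self, num_samples: int, seed: Optional[int] = None):
        self.num_samples = num_samples
        self.generator = torch.Generator()
        # Fall back to the process-wide seed so unseeded runs still differ.
        self.generator.manual_seed(seed if seed is not None else torch.initial_seed())

    def __iter__(self) -> Iterator[int]:
        while True:
            perm = torch.randperm(self.num_samples, generator=self.generator)
            yield from perm.tolist()


def take_epochs(loader, steps_per_epoch: int, num_epochs: int):
    """Yield (epoch, batch) from a single, never-exhausted loader iterator."""
    it = iter(loader)  # created once, so prefetching never stops between "epochs"
    for epoch in range(num_epochs):
        for batch in itertools.islice(it, steps_per_epoch):
            yield epoch, batch
```

Because the sampler has no __len__, len(loader) raises a TypeError, so steps_per_epoch has to be chosen explicitly (e.g. len(dataset) // batch_size). Unless the batch size divides the dataset size, batches straddle permutation boundaries, so an "epoch" here is only approximately one pass over the data.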
There might be a “fix”, but it would be interesting to see the performance improvements for some typical use cases first.
If the speedup is marginal (apart from extreme cases where an epoch uses only 1 sample), I’m not sure it’s worth the effort.
How big is the slowdown in each epoch for your current use case?