Do I understand the following correctly?
When num_workers >= 1, the main process pre-loads prefetch_factor * num_workers batches. When the training loop consumes one batch, the corresponding worker loads the next batch to replenish its queue.
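Concretely, these are the two knobs I mean (my_dataset is just a placeholder name):

from torch.utils.data import DataLoader

# with num_workers=4 and prefetch_factor=4, up to 4 * 4 = 16 batches would
# be in flight at once, if my understanding above is right
dl = DataLoader(my_dataset, batch_size=128, num_workers=4, prefetch_factor=4)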
If this is the case, let’s go through an example.
NOTE: I have chosen the numeric values for illustration purposes and
have ignored various overheads in this example. The accompanying code example uses these numbers.
Say I have num_workers=4, prefetch_factor=4, batch_size=128. Further assume it takes 0.003125 s to fetch an item from the source database and the train step takes 0.05 s. Each batch would then take 0.003125 * 128 = 0.4 s to load.
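Spelled out with these assumed numbers:

item_fetch = 0.003125           # s per item from the database
batch_size = 128
train_step = 0.05               # s per train step
batch_load = item_fetch * batch_size
print(batch_load)               # 0.4 s to load one batch
print(batch_load / train_step)  # 8.0 train steps elapse per batch fetch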
With a prefetch_factor=4 and num_workers=4, first 4 * 4 = 16 batches will be loaded. Once those 16 batches are loaded, the first train step consumes 1 batch and takes 0.05 s. Say worker[0] provided this batch; it then starts fetching a new batch to replenish its queue. Recall that fetching a new batch takes 0.4 s.
Similarly, the second step consumes one more batch and the corresponding worker (worker[1] in this example) starts fetching its replacement.
The first 8 train steps take 0.05 * 8 = 0.4 s. By this time, 8 batches have been consumed and worker[0] has produced 1 new batch. In the next step, 1 more batch is consumed and worker[1] delivers a new one; the fetch worker[1] started during the second train step has now completed. Following this pattern, each subsequent train step consumes 1 batch and one of the workers produces 1 batch, keeping the dataloader queue at 8 batches at all times. This means the train step never waits on data loading, as there are always 8 batches in the buffer.
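To make sure my mental model is consistent, here is a toy simulation of the schedule I just described (pure Python, no torch; it assumes the loader polls workers round-robin and each worker keeps its own prefetch_factor-deep queue, which is my reading of how the multiprocessing DataLoader works, not a verified fact):

num_workers, prefetch_factor = 4, 4
fetch, step = 0.4, 0.05                  # batch fetch time, train step time

ready = [prefetch_factor] * num_workers  # ready batches in each worker's queue
in_flight = [None] * num_workers         # completion time of current fetch
t = 0.0

def complete_fetches(now):
    # batches whose fetch finished by `now` become ready; a worker with room
    # left in its queue immediately starts the next fetch
    for w in range(num_workers):
        while in_flight[w] is not None and in_flight[w] <= now:
            done, ready[w] = in_flight[w], ready[w] + 1
            in_flight[w] = done + fetch if ready[w] < prefetch_factor else None

for i in range(48):                      # simulate 48 train steps
    w = i % num_workers                  # round-robin over workers
    complete_fetches(t)
    if ready[w] == 0:                    # main process blocks on this worker
        t = in_flight[w]
        complete_fetches(t)
    ready[w] -= 1                        # consume one batch
    if in_flight[w] is None:             # freed a slot: worker starts refilling
        in_flight[w] = t + fetch
    t += step                            # train step
    if (i + 1) % 8 == 0:
        print(f'after step {i + 1:2d}: t = {t:.2f} s, batches buffered = {sum(ready)}')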
I would expect this behavior regardless of the data size of the batch, given num_workers and prefetch_factor are large enough. However, in the following code example that is not the case.
In the code below, I define a custom iterable that returns a numpy array. As the size of the numpy array increases, increasing num_workers or prefetch_factor does not improve the time taken to run through a batch.
I’m guessing this is because each worker serializes the batch to send to the main process, where it is de-serialized. As the data size increases, this process takes more time. However, I would think that if the queue size (num_workers * prefetch_factor) is large enough, at some point there should be a break-even point where each training step's consumption of a batch is matched by replenishment from one of the workers, as I illustrated in the example above.
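To sanity-check the serialization guess, a rough measurement like this could help (assumption: a pickle round-trip approximates the worker-to-main transfer cost; the DataLoader's actual transport may behave differently):

import pickle
import time
import numpy as np

# how long does one 128-item batch take to serialize + de-serialize?
for shape in [(10, 150), (1000, 150)]:
    batch = [np.random.random(shape) for _ in range(128)]
    t0 = time.perf_counter()
    blob = pickle.dumps(batch, protocol=pickle.HIGHEST_PROTOCOL)
    pickle.loads(blob)
    t1 = time.perf_counter()
    print(f'shape {shape}: round-trip {t1 - t0:.4f} s ({len(blob) / 1e6:.1f} MB)')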
In the code below, when MyIterable returns a small object (np array of shape (10, 150)), increasing num_workers helps as expected. But when the returned object is larger (np array of shape (1000, 150)), num_workers or prefetch_factor does not do much.
# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253
Am I missing something here? Why doesn't the data loader queue hold enough buffered batches for data loading to stop being the bottleneck?
Even if the serialization and de-serialization process takes longer in the latter case, I'd expect a large enough buffer where the consumption and replenishment rates of batches are almost equal. Otherwise, what is the point of having prefetch_factor?
If the code is behaving as expected, are there any other ways to pre-load the next n batches into a buffer that is large enough to never be depleted?
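The only alternative I can think of is draining the DataLoader from a background thread into a bounded queue, so the buffer depth is decoupled from num_workers * prefetch_factor. A minimal sketch of what I mean (BufferedLoader and buffer_size are my own names, not a torch API):

import queue
import threading

class BufferedLoader:
    # hypothetical wrapper: a background thread keeps up to buffer_size
    # already-collated batches queued, independent of prefetch_factor
    _SENTINEL = object()

    def __init__(self, loader, buffer_size=32):
        self.loader = loader
        self.q = queue.Queue(maxsize=buffer_size)

    def _fill(self):
        for batch in self.loader:
            self.q.put(batch)  # blocks once the buffer is full
        self.q.put(self._SENTINEL)

    def __iter__(self):
        threading.Thread(target=self._fill, daemon=True).start()
        while True:
            batch = self.q.get()
            if batch is self._SENTINEL:
                return
            yield batch

# usage: for batch in BufferedLoader(dl, buffer_size=64): ...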
Thanks
import time
import torch
import numpy as np
from time import sleep
from torch.utils.data import DataLoader, IterableDataset


def collate_fn(records):
    # some custom collation function
    return records


class MyIterable(object):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.n:
            self.i += 1  # advance the counter so the iterator terminates
            sleep(0.003125)  # simulates data fetch time
            # return np.random.random((10, 150))  # small data item
            return np.random.random((1000, 150))  # large data item
        else:
            raise StopIteration


class MyIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return MyIterable(self.n)


def get_performance_metrics(num_workers):
    ds = MyIterableDataset(n=10000)
    if num_workers == 0:
        dl = DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
    else:
        dl = DataLoader(ds, num_workers=num_workers, prefetch_factor=4,
                        persistent_workers=True, batch_size=128,
                        collate_fn=collate_fn, multiprocessing_context='spawn')
    warmup = 5
    times = []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        e = time.perf_counter()
        if i >= warmup:
            times.append(e - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break
    print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')


if __name__ == '__main__':
    num_worker_options = [0, 2, 4, 6, 8]
    for n in num_worker_options:
        get_performance_metrics(n)