Queue in DataLoader does not behave as expected when using num_workers?

Do I understand the following correctly?

When num_workers >= 1, the workers pre-load a total of prefetch_factor * num_workers batches before training starts. When the training loop consumes one batch, the worker that produced it loads the next batch to refill the queue.
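
Concretely, the kind of setup I mean (a minimal sketch with a toy TensorDataset standing in for my real data):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Sketch of the setup in question: 4 workers, each keeping 4 batches ahead,
# so 4 * 4 = 16 batches should be pre-loaded before the first train step.
ds = TensorDataset(torch.randn(10000, 150))  # toy stand-in for my dataset
dl = DataLoader(ds, batch_size=128, num_workers=4, prefetch_factor=4)
for (batch,) in dl:
    pass  # train step would go here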

If this is the case, let’s go through an example.

NOTE: I have chosen these numeric values for illustration and ignored various overheads; the accompanying code example uses the same numbers.

Say I have num_workers=4, prefetch_factor=4, batch_size=128. Further assume it takes 0.003125 s to fetch one item from the source database and that a train step takes 0.05 s.

Now, each batch would take 0.003125 * 128 = 0.4 s to load.

With prefetch_factor=4 and num_workers=4, 4 * 4 = 16 batches are loaded up front.
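
Spelling out the arithmetic (a back-of-envelope sketch; the parallel-fill assumption is mine):

# Numbers from the example above, all overheads ignored.
item_time = 0.003125               # s to fetch one item
batch_size = 128
num_workers = 4
prefetch_factor = 4

batch_load_time = item_time * batch_size          # 0.4 s for one worker to load one batch
initial_batches = num_workers * prefetch_factor   # 16 batches pre-loaded up front
# assuming the 4 workers fill the queue in parallel, 4 batches at a time:
initial_fill_time = prefetch_factor * batch_load_time  # 1.6 s before the first step
print(batch_load_time, initial_batches, initial_fill_time)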

Once those 16 batches are loaded, the first train step consumes one batch in 0.05 s. Say worker[0] provided this batch; it then starts loading a new batch to replenish the queue, which (recall) takes 0.4 s.

Similarly, the second step consumes another batch, and the corresponding worker (worker[1] in this example) starts fetching its replacement.

The first 8 train steps take 0.05 * 8 = 0.4 s. By then, 8 batches have been consumed and worker[0] has produced one new batch. On the next step, one more batch is consumed and worker[1] delivers a new one: it started fetching during the second train step, so its batch is now ready.

Following this pattern, each subsequent train step consumes one batch while one of the workers produces one, so the DataLoader queue should always hold 8 batches. The train step would then never wait on data loading, since there are always 8 batches in the buffer.

I would expect this behavior regardless of the batch's data size, as long as num_workers and prefetch_factor are large enough. However, the code example below shows that this is not the case.
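
To make "large enough" concrete, the steady-state condition I am assuming is that the workers' combined production rate at least matches the training loop's consumption rate; roughly:

# Steady-state condition I am assuming (ignores IPC/serialization overheads):
# the workers together must produce batches at least as fast as they are consumed.
step_time = 0.05  # s per train step -> consumes 1 / 0.05 = 20 batches/s
load_time = 0.4   # s per batch per worker

for w in [2, 4, 6, 8]:
    produce_rate = w / load_time  # batches/s from all workers combined
    consume_rate = 1 / step_time
    print(w, produce_rate >= consume_rate)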

In the code below, I define a custom iterable that returns a numpy array. As the size of the numpy array grows, increasing num_workers or prefetch_factor no longer improves the average time per batch.

I'm guessing this is because each worker serializes the batch to send it to the main process, where it is de-serialized. As the data size increases, this takes more time. Still, I would think that if the queue is large enough (via num_workers and prefetch_factor), there should be a break-even point where each batch consumed by a training step is matched by a replenishment from one of the workers, as in the example above.
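
One quick way to sanity-check the serialization hypothesis is to time a pickle round-trip of one batch for both array sizes (my own rough sketch; the actual worker-to-main-process transfer may differ):

import pickle
import time
import numpy as np

# Rough check of the serialization hypothesis: time a pickle round-trip of
# one batch (128 records) for the small and the large array shapes.
for shape in [(10, 150), (1000, 150)]:
    batch = [np.random.random(shape) for _ in range(128)]
    t0 = time.perf_counter()
    restored = pickle.loads(pickle.dumps(batch))
    print(shape, time.perf_counter() - t0)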

In the code below, when MyIterable returns a small object (an np array of shape (10, 150)), increasing num_workers helps as expected. But when the returned object is larger (an np array of shape (1000, 150)), neither num_workers nor prefetch_factor makes much difference.

# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253

Am I missing something here? Why doesn't the DataLoader queue hold enough of a buffer to keep data loading from being the bottleneck?

Even if serialization and de-serialization take longer in the large-array case, I'd expect that with a large enough buffer the consumption and replenishment rates of batches would be roughly equal. Otherwise, what is the point of prefetch_factor?

If the code is behaving as expected, is there some other way to pre-load the next n batches into a buffer that is large enough to never be depleted?
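
For example, would something along these lines be a reasonable workaround: a wrapper that pulls batches from the DataLoader into a large queue.Queue on a background thread? (Just a sketch; BufferedLoader and the buffer size are my own invention, not a torch API.)

import queue
import threading

class BufferedLoader:
    # Hypothetical helper: drains any iterable (e.g. a DataLoader) into a
    # large buffer on a background thread.
    _DONE = object()  # sentinel marking the end of the underlying iterable

    def __init__(self, loader, buffer_size=64):
        self.loader = loader
        self.buffer = queue.Queue(maxsize=buffer_size)

    def _fill(self):
        for batch in self.loader:
            self.buffer.put(batch)
        self.buffer.put(self._DONE)

    def __iter__(self):
        threading.Thread(target=self._fill, daemon=True).start()
        while True:
            batch = self.buffer.get()
            if batch is self._DONE:
                return
            yield batch

# usage: for batch in BufferedLoader(dl, buffer_size=64): ...

Though since every batch still has to be de-serialized in the main process, I'm not sure this actually sidesteps the cost described above.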


import time
import torch
import numpy as np
from time import sleep
from torch.utils.data import DataLoader, IterableDataset

def collate_fn(records):
    # some custom collation function
    return records

class MyIterable(object):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.n:
            self.i += 1
            sleep(0.003125)  # simulates per-item fetch time
            # return np.random.random((10, 150))  # small data item
            return np.random.random((1000, 150))  # large data item
        raise StopIteration

class MyIterableDataset(IterableDataset):
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return MyIterable(self.n)

def get_performance_metrics(num_workers):
    ds = MyIterableDataset(n=10000)

    if num_workers == 0:
        # prefetch_factor and persistent_workers are only valid when num_workers > 0
        dl = DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
    else:
        dl = DataLoader(ds, num_workers=num_workers, prefetch_factor=4,
                        persistent_workers=True, batch_size=128,
                        collate_fn=collate_fn)

    warmup = 5
    times = []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        e = time.perf_counter()
        if i >= warmup:
            times.append(e - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break
    print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')

if __name__ == '__main__':
    num_worker_options = [0, 2, 4, 6, 8]
    for n in num_worker_options:
        get_performance_metrics(n)