Queue in DataLoader does not behave as expected when using num_workers?

Do I understand the following correctly?

When num_workers >= 1, the workers pre-load a total of prefetch_factor * num_workers batches. When the training loop consumes one batch, the main process hands out the next batch index and the corresponding worker starts loading that batch to refill its queue.

If this is the case, let’s go through an example.

NOTE: I have chosen the numeric values for illustration purposes and
have ignored various overheads in this example. The accompanying code example uses these numbers.

Say I have num_workers=4, prefetch_factor=4, batch_size=128. Further assume it takes 0.003125 s to fetch an item from a source database and the train step takes 0.05 s.

Now, each batch would take 0.003125 * 128 = 0.4 s to load.

With prefetch_factor=4 and num_workers=4, 4*4 = 16 batches will be loaded first.

Once the 16 batches are loaded, the first train step consumes 1 batch and takes 0.05 s. Say worker[0] provided this batch and will start the process to generate a new batch to replenish the queue. Recall fetching a new batch takes 0.4 s.

Similarly, the second step consumes one more batch and the corresponding worker (worker[1] in this example) starts the data fetching process.

The first 8 train steps take 0.05*8 = 0.4 s. By this time, 8 batches have been consumed and worker[0] has produced exactly 1 new batch; workers 1-3 finish theirs over the next few steps. So the 4 workers together replenish only 4 batches in the time it takes training to consume 8, and with these numbers the buffer slowly drains: 4 workers are not quite enough.

The break-even condition is num_workers * train_step_time >= batch_load_time, i.e. num_workers >= 0.4 / 0.05 = 8. With num_workers=8 (and the same prefetch_factor), one of the workers finishes a batch every 0.05 s, so each train step's consumption is matched by a replenishment and the queue stays full. In that regime the train step never waits for data loading, because there are always batches in the buffer.
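As a quick sanity check on these rates, here is a small back-of-envelope script of my own (it just re-derives the numbers above; none of it comes from PyTorch):

item_fetch_s = 0.003125   # time to fetch one item from the source database
batch_size = 128
train_step_s = 0.05       # time for one train step

batch_load_s = item_fetch_s * batch_size       # 0.4 s for one worker to build a batch
consumption_rate = 1 / train_step_s            # 20 batches/s consumed by training

for num_workers in (2, 4, 6, 8):
    production_rate = num_workers / batch_load_s  # workers build batches in parallel
    print(f'num_workers={num_workers}: produce {production_rate:.0f}/s, '
          f'consume {consumption_rate:.0f}/s, keeps up: {production_rate >= consumption_rate}')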

I would expect this behavior regardless of the batch's data size, given that num_workers and prefetch_factor are large enough. However, in the following code example that is not the case.

In the code below, I define a custom iterable that returns a numpy array. As the size of the numpy array increases, increasing num_workers or prefetch_factor does not improve the time taken per batch.

I'm guessing this is because each worker serializes the batch to send it to the main process, where it is de-serialized. As the data size increases, this takes more time. However, I would think that with a large enough queue (num_workers * prefetch_factor), there should be a break-even point at which each training step's consumption of a batch is matched by replenishment from one of the workers, as illustrated in the example above.
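As a rough proxy for that serialization cost, here is a measurement snippet of my own (the real transfer goes through a multiprocessing queue rather than a plain pickle round trip, so treat the numbers as approximate):

import pickle
import time
import numpy as np

# one batch from the large case below: 128 float64 arrays of shape (1000, 150), roughly 150 MB
batch = [np.random.random((1000, 150)) for _ in range(128)]

t0 = time.perf_counter()
blob = pickle.dumps(batch, protocol=pickle.HIGHEST_PROTOCOL)
t1 = time.perf_counter()
pickle.loads(blob)
t2 = time.perf_counter()
print(f'dumps: {t1 - t0:.3f} s, loads: {t2 - t1:.3f} s, payload: {len(blob) / 1e6:.0f} MB')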

In the code below, when MyIterable returns a small object (an np array of shape (10, 150)), increasing num_workers helps as expected. But when the returned object is larger (an np array of shape (1000, 150)), num_workers or prefetch_factor does not do much.

# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253

Am I missing something here? Why doesn’t the data loader queue have enough buffer such that data loading is not the bottleneck?

Even if serialization and de-serialization take longer in the latter case, I'd expect a large enough buffer to reach a point where the consumption and replenishment rates of batches are roughly equal. Otherwise, what is the point of having prefetch_factor?

If the code is behaving as expected, are there any other ways to pre-load the next n batches in a buffer such that it is large enough and never depleted?
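For completeness, the kind of buffering I have in mind is something like the thread-based prefetcher sketched below (my own code, not part of PyTorch, and the BackgroundPrefetcher name is made up). It wraps any iterator, including a DataLoader, and keeps up to buffer_size batches ready in a queue; whether it would actually help here is unclear to me, since the main-process deserialization still has to happen somewhere and is only overlapped with the train step:

import threading
import queue


class BackgroundPrefetcher:
    """Pull items from an iterable in a background thread into a bounded buffer."""
    _SENTINEL = object()

    def __init__(self, iterable, buffer_size=8):
        self._queue = queue.Queue(maxsize=buffer_size)
        self._thread = threading.Thread(target=self._fill, args=(iterable,), daemon=True)
        self._thread.start()

    def _fill(self, iterable):
        for item in iterable:
            self._queue.put(item)        # blocks while the buffer is full
        self._queue.put(self._SENTINEL)  # signal the end of iteration

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()
        if item is self._SENTINEL:
            raise StopIteration
        return item


# usage: for batch in BackgroundPrefetcher(dl, buffer_size=8): ...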

Thanks

import time
import torch
import numpy as np
from time import sleep
from torch.utils.data import DataLoader, IterableDataset


def collate_fn(records):
    # stand-in for a custom collation function; simply returns the list of records
    return records


class MyIterable(object):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.n:
            self.i += 1
            sleep(0.003125)  # simulates the per-item fetch time from the source database
            # return np.random.random((10, 150))  # small data item
            return np.random.random((1000, 150))  # large data item
        else:
            raise StopIteration


class MyIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return MyIterable(self.n)


def get_performance_metrics(num_workers):
    ds = MyIterableDataset(n=10000)

    if num_workers == 0:
        dl = DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
    else:
        dl = DataLoader(ds, num_workers=num_workers, prefetch_factor=4, persistent_workers=True,
                        batch_size=128, collate_fn=collate_fn,
                        multiprocessing_context='spawn')
    warmup = 5
    times = []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        e = time.perf_counter()
        if i >= warmup:
            times.append(e - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break
    print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')


if __name__ == '__main__':
    num_worker_options = [0, 2, 4, 6, 8]
    for n in num_worker_options:
        get_performance_metrics(n)

I have tested it too, and I think the answer is here, as written in the documentation:

prefetch_factor (int, optional , keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default: 2)

But there is a try in the implementation, which indicates that it will prefetch up to 2 * num_workers batches, so that is the maximum; if the workers are busy, the total number of prefetched batches will be less than that, according to the function _try_put_index in the implementation.

@Arij-Aladel I don't follow, could you please elaborate? _try_put_index does not have any logic for iterating up to 2 * num_workers as you mentioned.

@lkc1 were you able to figure this out?

So the problem isn't PyTorch; it is how multiprocessing works in Python, which PyTorch uses under the hood. With Python multiprocessing, data objects are serialized (pickled) to be passed between the parent and child processes. If the objects are not large, the overhead is negligible and you see the benefit of using multiple processes in the total time. However, if the data objects are large (in my case ~100-200 MB), the serialization adds enough overhead to neutralize any benefit you might get from multiprocessing.

So, I ended up with a somewhat complicated solution of creating the data loader in C++ and feeding that to PyTorch's DataLoader. That worked pretty well with respect to latency.


@lkc1 Thanks!
Got it. I was able to reproduce the same on a Jetson Orin (12 CPUs). Here are the results:

Small obj case:

avg time per batch for num workers=0: 0.459348852338735
avg time per batch for num workers=2: 0.2055776757770218
avg time per batch for num workers=4: 0.10303598642349243
avg time per batch for num workers=6: 0.07063477154588327
avg time per batch for num workers=8: 0.058643676224164665

Large obj case:

avg time per batch for num workers=0: 0.6697914574178867
avg time per batch for num workers=2: 0.44152931112330407
avg time per batch for num workers=4: 0.305355457123369
avg time per batch for num workers=6: 0.39800371177261695
avg time per batch for num workers=8: 0.3520681524532847

It would help me a lot to know where exactly the serialization/de-serialization happens in the multiprocessing module in Python. It would be great if you could point me to the lines of code. I need to dig deeper into the exact overhead caused by serialization/de-serialization in the latter case, where the objects are large. Please let me know how I can achieve this. Thanks a lot!
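In the meantime, here is a standalone sketch I put together (my own code, not PyTorch's) that times pushing one ~150 MB batch through a multiprocessing queue with the 'spawn' start method, which I assume roughly approximates the worker-to-main-process path. As far as I can tell, the pickling happens inside Python's multiprocessing.queues.Queue (on put, in a feeder thread) and the unpickling inside its get, but I have not tracked down the exact lines:

import time
import multiprocessing as mp
import numpy as np


def producer(ready_evt, go_evt, data_q):
    # build the batch up front so only the queue transfer is timed
    batch = [np.random.random((1000, 150)) for _ in range(128)]  # roughly 150 MB
    ready_evt.set()    # tell the parent the batch is ready
    go_evt.wait()      # wait for the parent's start signal
    data_q.put(batch)  # pickled on the way into the queue


if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    ready_evt, go_evt, data_q = ctx.Event(), ctx.Event(), ctx.Queue()
    p = ctx.Process(target=producer, args=(ready_evt, go_evt, data_q))
    p.start()
    ready_evt.wait()
    t0 = time.perf_counter()
    go_evt.set()
    batch = data_q.get()  # unpickled here, in the parent process
    t1 = time.perf_counter()
    print(f'one large batch through the queue: {t1 - t0:.3f} s')
    p.join()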

@Vinayaka_Hegde

Did you read the code?

# prime the prefetch loop
for _ in range(self._prefetch_factor * self._num_workers):
    self._try_put_index()

then you have:

def _try_put_index(self):
    assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

    try:
        index = self._next_index()
    except StopIteration:
        return
    for _ in range(self._num_workers):  # find the next active worker, if any
        worker_queue_idx = next(self._worker_queue_idx_cycle)
        if self._workers_status[worker_queue_idx]:
            break
    else:
        # not found (i.e., didn't break)
        return

    self._index_queues[worker_queue_idx].put((self._send_idx, index))
    self._task_info[self._send_idx] = (worker_queue_idx,)
    self._tasks_outstanding += 1
    self._send_idx += 1

Hope it is clear now: prefetch_factor * num_workers is only an upper bound on the number of outstanding prefetched batches, not a guarantee that the buffer stays full.