num_workers and prefetch_factor in DataLoader not scaling as expected

It seems that the serialization and deserialization associated with Python’s multiprocessing limit the benefits of processing data in parallel.

In the following example, I create a custom iterable that returns a numpy array. As the size of the numpy array increases, data fetching becomes the bottleneck. This is expected. However, I would expect that increasing num_workers and prefetch_factor would reduce this bottleneck by preparing batches in advance, but I do not see this behavior in the example below.

I test two cases, in which MyIterable returns:

  1. a small array, np.random.random((10, 150))
  2. a large array, np.random.random((1000, 150))
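For scale, a rough sketch of how much data one batch holds in each case, assuming the float64 dtype that np.random.random returns and the batch_size=128 used in the script. The helper name batch_nbytes is hypothetical:

```python
import numpy as np

BATCH_SIZE = 128  # same batch_size as in the DataLoader calls


def batch_nbytes(shape, batch_size=BATCH_SIZE):
    """Bytes held by one batch of float64 arrays of the given shape (hypothetical helper)."""
    item = np.random.random(shape)  # float64, 8 bytes per element
    return item.nbytes * batch_size


for label, shape in [("small", (10, 150)), ("large", (1000, 150))]:
    print(f"{label}: {batch_nbytes(shape) / 1e6:.3f} MB per batch")
# small: 1.536 MB per batch
# large: 153.600 MB per batch
```

So each large batch that has to cross the process boundary is 100 times bigger than a small one.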

The average time per batch in the two scenarios is as follows:

# small np object
avg time per batch for num workers=0: 0.47068126868714444
avg time per batch for num workers=2: 0.20982365206225495
avg time per batch for num workers=4: 0.10560789656221914
avg time per batch for num workers=6: 0.07202646931250456
avg time per batch for num workers=8: 0.05311137337469063
# large np object
avg time per batch for num workers=0: 0.6090951558124971
avg time per batch for num workers=2: 0.4594530961876444
avg time per batch for num workers=4: 0.45023533212543043
avg time per batch for num workers=6: 0.3830978863124983
avg time per batch for num workers=8: 0.3811495694375253

For the small object, the time per batch drops as expected when num_workers is increased. For the larger object, however, it barely changes. I attribute this to the fact that each worker process has to serialize the numpy object and the main process then has to deserialize it; the larger the object, the longer this takes.
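To check the serialization hypothesis in isolation, a minimal single-process sketch that times pickling (the worker side) and unpickling (the main-process side, which, as I understand multiprocessing queues, happens when the main process reads from the result queue) for one batch shaped like the one collate_fn returns: a list of 128 arrays. The function name pickle_round_trip is hypothetical; the sketch leaves out the pipe transfer itself, and timings will vary by machine:

```python
import pickle
import time

import numpy as np


def pickle_round_trip(shape, batch_size=128):
    """Pickle and unpickle one batch (a list of float64 arrays).

    Returns (payload_bytes, dumps_seconds, loads_seconds).
    """
    batch = [np.random.random(shape) for _ in range(batch_size)]
    t0 = time.perf_counter()
    payload = pickle.dumps(batch)  # roughly what a worker does before sending
    t1 = time.perf_counter()
    restored = pickle.loads(payload)  # roughly what the main process does on receipt
    t2 = time.perf_counter()
    assert len(restored) == batch_size
    return len(payload), t1 - t0, t2 - t1


for shape in [(10, 150), (1000, 150)]:
    size, dump_s, load_s = pickle_round_trip(shape)
    print(f"{shape}: {size / 1e6:.1f} MB, dumps {dump_s * 1e3:.1f} ms, loads {load_s * 1e3:.1f} ms")
```

Only the unpickling half (plus reading the bytes) lands on the main process, so that is the part that extra workers cannot parallelize away.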

However, with large enough num_workers and prefetch_factor, shouldn’t the queue in the DataLoader always be full, so that data fetching is no longer the bottleneck?
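As a back-of-the-envelope check, assume fetching costs 128 × 0.003125 s = 0.4 s per batch (the sleep in MyIterable), the train step costs 0.05 s, and moving data between processes costs nothing. Under those assumptions the best case is fetch + step with no workers, and max(step, fetch / num_workers) with workers, because the slowest stage sets the pace. The function name ideal_time_per_batch is hypothetical:

```python
FETCH_PER_ITEM_S = 0.003125  # sleep() in MyIterable.__next__
BATCH_SIZE = 128
STEP_S = 0.05  # sleep() standing in for the train step


def ideal_time_per_batch(num_workers):
    """Best-case seconds per batch, ignoring any inter-process transfer cost."""
    fetch_s = FETCH_PER_ITEM_S * BATCH_SIZE  # 0.4 s of fetching per batch
    if num_workers == 0:
        return fetch_s + STEP_S  # fetching and training run back to back
    return max(STEP_S, fetch_s / num_workers)  # workers fetch in parallel


for n in [0, 2, 4, 6, 8]:
    print(f"num_workers={n}: {ideal_time_per_batch(n):.3f} s")
# num_workers=0: 0.450 s
# num_workers=2: 0.200 s
# num_workers=4: 0.100 s
# num_workers=6: 0.067 s
# num_workers=8: 0.050 s
```

These numbers line up closely with the small-object measurements (0.47, 0.21, 0.11, 0.072, 0.053), while the large-object measurements stay well above them, which is consistent with an extra per-batch cost that this estimate leaves out.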

Moreover, changing prefetch_factor does not change anything. What is the point of prefetch_factor? The documentation says each worker loads prefetch_factor batches in advance (num_workers * prefetch_factor batches in total), but as you can see, this does nothing to reduce the bottleneck.
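One way to reason about prefetch_factor is a toy discrete-event model of the pipeline. The assumptions here are mine and are not a description of PyTorch internals: batches are requested round-robin from the workers, at most num_workers * prefetch_factor requests are outstanding, a new request is issued whenever the main process receives a batch, each worker needs fetch_s per batch, and the main process spends main_s per batch (a stand-in for deserialization) plus step_s for the train step. The function simulate is hypothetical:

```python
def simulate(num_workers, prefetch_factor, fetch_s, step_s, main_s=0.0,
             num_batches=1000, warmup=100):
    """Toy deterministic model; returns average seconds per batch after warmup.

    Requires num_workers >= 1 and prefetch_factor >= 1 (hypothetical model).
    """
    window = num_workers * prefetch_factor  # max outstanding batch requests
    issued = [0.0] * num_batches            # when each batch is requested
    worker_free = [0.0] * num_workers       # when each worker finishes its last batch
    received = [0.0] * num_batches          # when the main process holds each batch
    step_end = 0.0                          # when the previous train step finished
    for k in range(num_batches):
        w = k % num_workers                 # round-robin assignment
        done = max(issued[k], worker_free[w]) + fetch_s
        worker_free[w] = done
        # wait for the batch and the previous step, then pay the main-process cost
        received[k] = max(done, step_end) + main_s
        if k + window < num_batches:
            issued[k + window] = received[k]  # refill the request window
        step_end = received[k] + step_s
    return (received[-1] - received[warmup]) / (num_batches - 1 - warmup)


for pf in [1, 2, 4, 8]:
    no_cost = simulate(8, pf, fetch_s=0.4, step_s=0.05)
    with_cost = simulate(8, pf, fetch_s=0.4, step_s=0.05, main_s=0.3)
    print(f"prefetch_factor={pf}: {no_cost:.3f} s, with main-process cost {with_cost:.3f} s")
```

In this model a deeper prefetch buffer only absorbs start-up delay and jitter; the steady-state time per batch comes out as max(step_s + main_s, fetch_s / num_workers) for every prefetch_factor, so a fixed cost paid serially on the main process is not hidden by prefetching.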

I have added a more detailed step-by-step analysis in this question for reference.

import time
import torch
import numpy as np
from time import sleep
from torch.utils.data import DataLoader, IterableDataset


def collate_fn(records):
    # some custom collation function
    return records


class MyIterable(object):
    def __init__(self, n):
        self.n = n
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.n:
            self.i += 1  # count items so the iterator stops after n of them
            sleep(0.003125)  # simulates data fetch time
            # return np.random.random((10, 150))  # small data item
            return np.random.random((1000, 150))  # large data item
        else:
            raise StopIteration


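class MyIterableDataset(IterableDataset):
    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        # With num_workers > 0, each worker process calls __iter__ on its own copy
        # of the dataset, so every worker yields all n items unless the work is
        # split with torch.utils.data.get_worker_info().
        return MyIterable(self.n)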


def get_performance_metrics(num_workers):
    ds = MyIterableDataset(n=10000)

    if num_workers == 0:
        # prefetch_factor and persistent_workers only apply when num_workers > 0
        dl = DataLoader(ds, num_workers=0, batch_size=128, collate_fn=collate_fn)
    else:
        dl = DataLoader(ds, num_workers=num_workers, prefetch_factor=4, persistent_workers=True,
                        batch_size=128, collate_fn=collate_fn,
                        multiprocessing_context='spawn')
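    warmup = 5  # skip the first few batches so worker start-up is excluded
    times = []
    t0 = time.perf_counter()
    for i, batch in enumerate(dl):
        sleep(0.05)  # simulates train step
        t1 = time.perf_counter()
        if i >= warmup:
            # wall time since the previous step ended: waiting for this batch
            # plus the simulated train step
            times.append(t1 - t0)
        t0 = time.perf_counter()
        if i >= 20:
            break  # only 21 batches are consumed in total
    print(f'avg time per batch for num workers={num_workers}: {sum(times) / len(times)}')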


if __name__ == '__main__':
    num_worker_options = [0, 2, 4, 6, 8]
    for n in num_worker_options:
        get_performance_metrics(n)