Unexpected parallel data loader performance using IterableDatasets compared to (map-style) Datasets with num_workers > 1


I’m trying to diagnose a performance discrepancy between IterableDatasets and (map-style) Datasets in a multiprocess data loading setting.

My experiment (code at the end of the post) consisted of:

  1. Create an IterableDataset or a map-style Dataset. It synthesizes dummy data and can optionally add a time delay to simulate the work of loading a batch.
  2. Create a DataLoader that consumes the dataset instance above. The number of workers ranges from 0 (loading in the main process) to 4.
  3. Iterate through 100 batches yielded by the DataLoader, optionally adding a time delay per batch to simulate processing work (e.g. training).

I varied the number of workers, the batch loading time, and the batch processing time, and plotted the results:

[Figure: IterableDataset (left) vs. (map-style) Dataset (right)]

The vertical axis is the time to iterate through N = 100 batches, divided by the time taken with num_workers = 0 for the same configuration. The results are grouped by batch processing time and batch loading time.
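A rough, idealized model of what the normalized time should look like. This is an assumption: it ignores worker start-up, prefetch warm-up, and inter-process communication, and it assumes every item is loaded exactly once. Let $N$ be the number of batches, $t_\ell$ the per-batch loading time, $t_p$ the per-batch processing time, and $T_N(w)$ the wall time with $w$ workers. The plotted quantity is

$$
r(w) = \frac{T_N(w)}{T_N(0)}.
$$

With $w = 0$, loading and processing run one after the other, so $T_N(0) \approx N\,(t_\ell + t_p)$. With $w \ge 1$ workers loading in parallel with the main process, throughput is limited by the slower of the consumer (one batch per $t_p$) and the workers ($w$ batches per $t_\ell$), so $T_N(w) \approx N \max(t_p,\, t_\ell / w)$ and

$$
r(w) \approx \frac{\max(t_p,\; t_\ell / w)}{t_\ell + t_p}, \qquad w \ge 1.
$$

Under this model $r(w)$ never increases with $w$. For example, with $t_\ell = 1$ s and $t_p = 0.1$ s it gives $r(1) \approx 0.91$, $r(2) \approx 0.45$, and $r(4) \approx 0.23$.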

The map-style Dataset results match what I expected. However, I did not expect the run time to degrade linearly with the number of workers (for num_workers >= 1) when using IterableDataset.

Could someone please look at my code below and let me know whether I’m using IterableDataset incorrectly, or whether this might be a bug worth knowing about or reporting?


```python
import os
import time
import itertools
import numpy as np
import torch

class TimedBlock:
    """Context manager to measure wall time."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.t_start = time.perf_counter()
        return self

    def __exit__(self, *args):
        print("(%d) %s execution finished. Time elapsed = %.3fs." % (os.getpid(), self.name,
                                                                     time.perf_counter() - self.t_start))

class IterableDataset(torch.utils.data.IterableDataset):
    """Synthetic iterable dataset that simulates data loading work."""

    def __init__(self, n_items, shape, batch_loading_time):
        """
        Args:
            n_items: number of items to return.
            shape: shape of batch tensor to return.
            batch_loading_time: each batch will take at least this many seconds to return.
        """
        self.n_items = n_items
        self.shape = shape
        self.batch_loading_time = batch_loading_time

    def __iter__(self):
        self.i = 0
        while self.i < self.n_items:
            t_start = time.perf_counter()
            X = np.random.randn(*self.shape)
            # Sleep for whatever is left of the simulated loading time.
            t_remaining = self.batch_loading_time - (time.perf_counter() - t_start)
            if t_remaining > 0:
                time.sleep(t_remaining)
            yield X
            self.i += 1

class Dataset(torch.utils.data.Dataset):
    """Synthetic map-style dataset that simulates data loading work."""

    def __init__(self, n_items, shape, batch_loading_time):
        """
        Args:
            n_items: number of items to return.
            shape: shape of batch tensor to return.
            batch_loading_time: each batch will take at least this many seconds to return.
        """
        self.n_items = n_items
        self.shape = shape
        self.batch_loading_time = batch_loading_time

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        t_start = time.perf_counter()
        X = np.random.randn(*self.shape)
        # Sleep for whatever is left of the simulated loading time.
        t_remaining = self.batch_loading_time - (time.perf_counter() - t_start)
        if t_remaining > 0:
            time.sleep(t_remaining)
        return X

def test_simpleloader(n_iters=int(1e2), shape=(1,), num_workers=0, batch_loading_time=0, batch_process_time=0,
                      exp_name="simpleloader", dataset_cls=IterableDataset):
    """Load and process n_iters batches of data using a dataloader and one of the above dataset classes. We simulate
    training by waiting batch_process_time seconds per batch.
    """
    dataset = dataset_cls(n_iters, shape, batch_loading_time)
    data_loader = torch.utils.data.DataLoader(dataset, num_workers=num_workers)
    with TimedBlock(exp_name):
        for i, x in enumerate(data_loader):
            time.sleep(batch_process_time)  # simulate processing (e.g. training) work

def simpleloader_grid(num_workers_grid=[0, 1, 2, 3, 4],
                      batch_loading_time_grid=[10**i for i in range(0, -2, -1)],
                      batch_process_time_grid=[10**i for i in range(0, -2, -1)],
                      dataset_cls=IterableDataset):
    """Run test_simpleloader for every (num_workers, batch_loading_time, batch_process_time) combination."""
    n_iters = int(1e2)
    shape = (1,)

    params = itertools.product(num_workers_grid, batch_loading_time_grid, batch_process_time_grid)

    for p in params:
        test_simpleloader(n_iters, shape, p[0], p[1], p[2], str(p), dataset_cls=dataset_cls)

if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")  # probably not important here
    t_start = time.perf_counter()
    # Run the full grid for both dataset types (left and right panels of the plot).
    for dataset_cls in (IterableDataset, Dataset):
        simpleloader_grid(dataset_cls=dataset_cls)
    print("Main execution time = %.3fs" % (time.perf_counter() - t_start))
```
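One quick check for whether the linear slowdown comes from the IterableDataset being copied into every worker. The PyTorch documentation states that, with multi-process loading, each worker process gets its own replica of an IterableDataset, so unless `__iter__` splits the work (for example with `torch.utils.data.get_worker_info()`), every worker yields the full stream. The sketch below is an illustration, not a drop-in fix. `CountingIterable` and `count_batches` are hypothetical names that are not part of the script above. The sketch counts how many items the DataLoader returns with and without per-worker sharding.

```python
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class CountingIterable(IterableDataset):
    """Hypothetical toy dataset that yields the integers 0..n_items-1."""

    def __init__(self, n_items, shard=False):
        self.n_items = n_items
        self.shard = shard

    def __iter__(self):
        info = get_worker_info()
        if info is None or not self.shard:
            # Main process (num_workers=0) or no sharding: this copy yields every item.
            return iter(range(self.n_items))
        # Inside a worker: take every num_workers-th item, offset by this worker's id.
        return iter(range(info.id, self.n_items, info.num_workers))

def count_batches(num_workers, shard):
    """Count the items the DataLoader yields for a 100-item CountingIterable."""
    loader = DataLoader(CountingIterable(100, shard=shard), batch_size=None, num_workers=num_workers)
    return sum(1 for _ in loader)

if __name__ == "__main__":
    for w in (0, 2):
        print(w, count_batches(w, shard=False), count_batches(w, shard=True))
```

Without sharding, the count should grow with the number of workers (two workers each yield all 100 items). If each item carries a fixed loading and processing cost, that would be consistent with a wall time that grows linearly in num_workers. A map-style Dataset does not have this issue, because the sampler hands each index to exactly one worker.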