I’m trying to diagnose a performance discrepancy between IterableDatasets and map-style Datasets when they are used with a multi-process DataLoader.
My experiment (code at the end of the post) consisted of:
- Create an IterableDataset or a map-style Dataset. It synthesizes dummy data and optionally adds a time delay to simulate batch loading work.
- Create a DataLoader that consumes the above dataset instance. The number of workers varies from 0 (loading in the main process) to 4.
- Iterate through 100 batches of data yielded by the DataLoader, optionally with a time delay per batch to simulate batch processing work (e.g. training).
I varied the number of workers, the batch loading time, and the batch processing time, and plotted the results:
[Figure: IterableDataset (left) vs. map-style Dataset (right)]
The vertical axis is the time to run through N = 100 batches, divided by the time to do so with num_workers = 0 for the same settings. I grouped the results by batch processing / loading time.
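To make the normalization concrete, here is a minimal sketch of how the relative times can be computed. The helper name `normalize_by_baseline` and the `example_timings` values are hypothetical placeholders for illustration, not my measurements.

```python
def normalize_by_baseline(timings):
    """Divide each wall time by the num_workers == 0 time of the same group.

    `timings` maps (num_workers, batch_loading_time, batch_process_time)
    to the wall time in seconds for running through N batches.
    """
    relative = {}
    for (num_workers, load_t, proc_t), elapsed in timings.items():
        # Baseline: same loading/processing times, loading in the main process.
        baseline = timings[(0, load_t, proc_t)]
        relative[(num_workers, load_t, proc_t)] = elapsed / baseline
    return relative


# Placeholder numbers for illustration only (not measured results).
example_timings = {
    (0, 0.1, 0.1): 20.0,
    (1, 0.1, 0.1): 10.0,
    (2, 0.1, 0.1): 10.0,
}
relative_times = normalize_by_baseline(example_timings)
print(relative_times)
```

A value below 1 means the multi-worker run was faster than the num_workers = 0 run for that group, and a value above 1 means it was slower.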
The behavior of the map-style Datasets is what I expected, but I did not expect performance to degrade linearly with the number of workers for num_workers >= 1 when using IterableDatasets.
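One thing I have not ruled out: as far as I understand the PyTorch documentation, with num_workers > 0 each worker process receives its own copy of an IterableDataset, so unless `__iter__` splits the work, every worker may yield all n_items batches. The sketch below illustrates that idea in plain Python, without a DataLoader. The helper names (`replicated_iter`, `sharded_iter`, `total_items`) are hypothetical; in real code the worker id and worker count would come from `torch.utils.data.get_worker_info()`.

```python
def replicated_iter(n_items, worker_id, num_workers):
    """Every simulated worker yields the full range (no sharding)."""
    for i in range(n_items):
        yield i


def sharded_iter(n_items, worker_id, num_workers):
    """Each simulated worker yields every num_workers-th item, starting at worker_id."""
    for i in range(worker_id, n_items, num_workers):
        yield i


def total_items(iter_fn, n_items, num_workers):
    """Count the items produced across all simulated workers."""
    return sum(
        sum(1 for _ in iter_fn(n_items, worker_id, num_workers))
        for worker_id in range(num_workers)
    )


n_items = 100
for num_workers in (1, 2, 4):
    print(num_workers,
          total_items(replicated_iter, n_items, num_workers),
          total_items(sharded_iter, n_items, num_workers))
```

If this applies to my experiment, the timed loop would run through num_workers × 100 batches instead of 100, which could account for the linear growth. I have not verified this.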
Could someone please have a look at my code below and let me know if I’m using IterableDataset incorrectly or if there might be a bug we should know about/report?
import os
import time
import itertools

import numpy as np
import torch


class TimedBlock:
    """
    Context manager to measure wall time.
    """
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.t_start = time.perf_counter()

    def __exit__(self, *args):
        print("(%d) %s execution finished. Time elapsed = %.3fs."
              % (os.getpid(), self.name, time.perf_counter() - self.t_start))


class IterableDataset(torch.utils.data.IterableDataset):
    """
    Synthetic iterable dataset that simulates data loading work.
    """
    def __init__(self, n_items, shape, batch_loading_time):
        """
        Args:
            n_items: number of items to return.
            shape: shape of batch tensor to return.
            batch_loading_time: each batch will take at least this many seconds to return.
        """
        self.n_items = n_items
        self.shape = shape
        self.batch_loading_time = batch_loading_time

    def __iter__(self):
        self.i = 0
        while self.i < self.n_items:
            t_start = time.perf_counter()
            X = np.random.randn(*self.shape)
            # Sleep for whatever is left of the simulated loading time.
            t_remaining = self.batch_loading_time - (time.perf_counter() - t_start)
            if t_remaining > 0:
                time.sleep(t_remaining)
            yield X
            self.i += 1


class Dataset(torch.utils.data.Dataset):
    """
    Synthetic map-style dataset that simulates data loading work.
    """
    def __init__(self, n_items, shape, batch_loading_time):
        """
        Args:
            n_items: number of items to return.
            shape: shape of batch tensor to return.
            batch_loading_time: each batch will take at least this many seconds to return.
        """
        self.n_items = n_items
        self.shape = shape
        self.batch_loading_time = batch_loading_time

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        t_start = time.perf_counter()
        X = np.random.randn(*self.shape)
        # Sleep for whatever is left of the simulated loading time.
        t_remaining = self.batch_loading_time - (time.perf_counter() - t_start)
        if t_remaining > 0:
            time.sleep(t_remaining)
        return X


def test_simpleloader(n_iters=int(1e2), shape=(1,), num_workers=0,
                      batch_loading_time=0, batch_process_time=0,
                      exp_name="simpleloader", dataset_cls=IterableDataset):
    """
    Load and process n_iters batches of data using a dataloader and one of the above dataset classes.
    We simulate training by waiting batch_process_time seconds per batch.
    """
    dataset = dataset_cls(n_iters, shape, batch_loading_time)
    # batch_size=None disables automatic batching: each dataset item is one batch.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=None,
                                              num_workers=num_workers,
                                              collate_fn=None, pin_memory=False,
                                              worker_init_fn=None,
                                              multiprocessing_context=None)
    with TimedBlock(exp_name):
        for i, x in enumerate(data_loader):
            time.sleep(batch_process_time)


def simpleloader_grid(num_workers_grid=[0, 1, 2, 3, 4],
                      batch_loading_time_grid=[10**i for i in range(0, -2, -1)],
                      batch_process_time_grid=[10**i for i in range(0, -2, -1)],
                      dataset_cls=IterableDataset):
    n_iters = int(1e2)
    shape = (1,)
    params = itertools.product(num_workers_grid, batch_loading_time_grid,
                               batch_process_time_grid)
    for p in params:
        # p = (num_workers, batch_loading_time, batch_process_time)
        test_simpleloader(n_iters, shape, p[0], p[1], p[2], str(p),
                          dataset_cls=dataset_cls)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")  # probably not important here
    t_start = time.perf_counter()
    simpleloader_grid(dataset_cls=IterableDataset)
    print("Main execution time = %.3fs" % (time.perf_counter() - t_start))