Multiprocessing DataLoader is not faster for IterableDataset


I am trying to chain multiple IterableDatasets, similar to Torch DataPipes, and want to use a multiprocessing DataLoader to speed up loading, but the improvement is very small and comes at the cost of a larger memory footprint. The code below is an example. In my actual code there are more chained IterableDatasets, where each NumpyDS is loaded from a file and some transformations are applied. Can anyone help me understand whether I am using DataLoader correctly, and explain what is wrong with using it this way? I am testing with Torch 2.0.0 on an Amazon p3.2xlarge instance. Thanks in advance!

Create fake data files

import numpy as np
import pandas as pd

root_path = '/home/ubuntu/test-data'
num_files = 16
num_samples_per_file = 6000
num_features = 60
col_names = [f'col_{i}' for i in range(num_features)]

for i in range(num_files):
    data = np.random.rand(num_samples_per_file, 208, num_features).astype('float32')
    df = pd.DataFrame(data.reshape((-1, num_features)), columns=col_names)
    df.to_parquet(f'{root_path}/{i}.parquet', index=False, engine='pyarrow')

:~/test-data$ ls
0.parquet  10.parquet  12.parquet  14.parquet  2.parquet  4.parquet  6.parquet  8.parquet
1.parquet  11.parquet  13.parquet  15.parquet  3.parquet  5.parquet  7.parquet  9.parquet
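For scale, the decoded data is fairly large: each file holds 6000 × 208 × 60 float32 values. A quick calculation, in decimal units, ignoring pandas/pyarrow overhead and any Parquet compression on disk:

```python
# Decoded (in-memory) size of the fake data written by the script above.
# float32 = 4 bytes per value; Parquet compression on disk is ignored.
num_files = 16
num_samples_per_file = 6000
seq_len = 208          # second dimension of the generated array
num_features = 60
bytes_per_value = 4    # float32

bytes_per_file = num_samples_per_file * seq_len * num_features * bytes_per_value
total_bytes = num_files * bytes_per_file

print(f"per file: {bytes_per_file / 1e6:.1f} MB")   # 299.5 MB
print(f"all files: {total_bytes / 1e9:.2f} GB")     # 4.79 GB
```

So every file that FileReader loads becomes roughly a 300 MB array in memory, and all 16 files together amount to about 4.8 GB.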

Load data with multiple workers


import os
import sys
import time
import torch
from torch.utils.data import IterableDataset, DataLoader
import pyarrow.parquet as pq

root_path = '/home/ubuntu/test-data'
num_files = 16
num_samples_per_file = 6000
num_features = 60
col_names = [f'col_{i}' for i in range(num_features)]

class FileLister(IterableDataset):
    def __init__(self, path):
        self.path = path
        self.files = [f'{path}/{p}' for p in os.listdir(self.path) if p.endswith('.parquet')]

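    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        if info is None:
            # Single-process loading: yield every file.
            yield from self.files
        else:
            # Multi-process loading: each worker takes a disjoint, strided subset of the files.
            yield from self.files[info.id::info.num_workers]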

class FileReader(IterableDataset):
    def __init__(self, paths):
        self.paths = paths

    def __iter__(self):
        for p in self.paths:
            data = pq.read_table(p).to_pandas().to_numpy().astype('float32').reshape((-1, 208, 60))
            yield data

class Batcher(IterableDataset):
    def __init__(self, datasets, batch_size):
        self.datasets = datasets
        self.batch_size = batch_size

    def __iter__(self):
        for ds in self.datasets:
            for i in range(0, len(ds), self.batch_size):
                end = min(len(ds), i + self.batch_size)
                yield ds[i:end]

def batchify(data):
    return data[0]

if __name__ == '__main__':
    lister = FileLister(root_path)
    reader = FileReader(lister)
    batcher = Batcher(reader, 512)
    num_workers = int(sys.argv[1])
    persistent_workers = num_workers > 0
    dl = DataLoader(batcher, num_workers=num_workers, drop_last=False, batch_size=1, collate_fn=batchify,
                    persistent_workers=persistent_workers)
    count = 0
    start = time.time()
    for b in dl:
        count += 1
    print(f'num_workers: {dl.num_workers}')
    print(f'iterate over {count} batches in {time.time()-start} seconds')
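Whatever split FileLister uses, it has to give each worker a disjoint subset of files that together cover all 16, otherwise the batch count would change with num_workers. The sketch below checks that property for a strided split, files[worker_id::num_workers], using a hypothetical `shard` helper and 3 workers (an assumption chosen to show an uneven split; a contiguous split would work as well). It needs no torch, only the slicing rule.

```python
def shard(files, worker_id, num_workers):
    """Hypothetical helper: the subset of `files` one worker would read
    under a strided split."""
    return files[worker_id::num_workers]

files = [f"{i}.parquet" for i in range(16)]
num_workers = 3

shards = [shard(files, w, num_workers) for w in range(num_workers)]
for w, s in enumerate(shards):
    print(w, len(s), s)

# Every file is assigned to exactly one worker, so nothing is read twice
# and nothing is skipped.
assigned = sorted(f for s in shards for f in s)
assert assigned == sorted(files)
```

With 16 files and 8 workers the same rule gives each worker 2 files, which it reads one after the other.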

Results (num_workers = 0 means single-process loading):
$ python 0
num_workers: 0
iterate over 192 batches in 5.767262697219849 seconds

$ python 1
num_workers: 1
iterate over 192 batches in 26.10194158554077 seconds

$ python 2
num_workers: 2
iterate over 192 batches in 15.331105470657349 seconds

$ python 4
num_workers: 4
iterate over 192 batches in 14.00988245010376 seconds

$ python 8
num_workers: 8
iterate over 192 batches in 14.190296411514282 seconds
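The batch count in these logs can be checked by hand: Batcher slices each file into chunks of 512 samples, so each file yields ceil(6000 / 512) = 12 batches, the last one partial. The same arithmetic gives the size of each batch, which is relevant because with num_workers > 0 every batch is produced in a worker process and has to be passed back to the main process.

```python
import math

# Numbers taken from the scripts above; float32 = 4 bytes per value.
num_files = 16
num_samples_per_file = 6000
batch_size = 512
seq_len, num_features, bytes_per_value = 208, 60, 4

# Batcher slices each file into chunks of batch_size; the last chunk is partial.
batches_per_file = math.ceil(num_samples_per_file / batch_size)
total_batches = num_files * batches_per_file

bytes_per_sample = seq_len * num_features * bytes_per_value
full_batch_bytes = batch_size * bytes_per_sample
last_batch_samples = num_samples_per_file - (batches_per_file - 1) * batch_size

print(f"batches per file: {batches_per_file}, total: {total_batches}")
print(f"full batch: {full_batch_bytes / 1e6:.1f} MB, last batch: {last_batch_samples} samples")
```

This reproduces the 192 batches reported for every num_workers setting, with each full batch being about 25.6 MB.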

Using multiple workers with a fake dataset might not yield the expected speedup, since the actual data loading and processing are missing (this was discussed and profiled a long time ago already). I would recommend profiling the real workload to see whether multiple workers can properly prefetch the samples.
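One way to profile the real pipeline is to time each stage in the main process (num_workers=0) before adding workers. The sketch below uses a hypothetical `StageTimer` wrapper, not a PyTorch API; it accumulates the wall-clock time spent waiting for each item, which includes the wrapped stage and everything upstream of it. `slow_source` is a stand-in for an expensive stage such as reading a Parquet file.

```python
import time

class StageTimer:
    """Hypothetical helper: wraps any iterable and accumulates the time
    spent waiting for each item (wrapped stage plus upstream stages)."""

    def __init__(self, iterable, name):
        self.iterable = iterable
        self.name = name
        self.total = 0.0
        self.count = 0

    def __iter__(self):
        it = iter(self.iterable)
        while True:
            start = time.perf_counter()
            try:
                item = next(it)
            except StopIteration:
                return
            self.total += time.perf_counter() - start
            self.count += 1
            yield item

def slow_source(n, delay):
    """Stand-in for an expensive stage (e.g. decoding a file)."""
    for i in range(n):
        time.sleep(delay)
        yield i

timed = StageTimer(slow_source(5, 0.01), "source")
items = list(timed)
print(f"{timed.name}: {timed.count} items, {timed.total:.3f} s")
```

Because Batcher only iterates over its input, the wrapper could be slotted into the example as `reader = StageTimer(FileReader(lister), 'read')` followed by `Batcher(reader, 512)`, and `reader.total` compared with the overall loop time. With num_workers > 0 the counters live in the worker processes, so they are not visible from the main process.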

@ptrblck Thanks a lot for your replies! I updated the example code above: I created some test files and stored them in Parquet format, which is the format of my actual data. Multi-worker loading is even slower than single-process loading. Can you please share any thoughts on what I am doing wrong?

Thanks for sharing your example! Can you add print statements inside the iterable datasets to check whether the workers are actually running in parallel?
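A minimal way to add such prints without changing each dataset's logic is a small generator that passes items through and logs the process id and a timestamp for each one. `traced` is a hypothetical helper; `time.time()` is used because its values are comparable across processes on the same machine. If log lines from different PIDs interleave in time, the workers are producing data concurrently.

```python
import os
import time

def traced(iterable, name, log=print):
    """Hypothetical helper: yield items unchanged, logging wall-clock time,
    process id, stage name and item index as each item is produced."""
    for index, item in enumerate(iterable):
        log(f"{time.time():.3f} pid={os.getpid()} {name} yielded item {index}")
        yield item

logs = []
out = list(traced(["a", "b"], "demo", log=logs.append))
for line in logs:
    print(line)
```

In the example above, changing Batcher.__iter__ to loop over `traced(self.datasets, 'read')` would print a line every time a worker finishes reading a file; worker processes inherit stdout, so the prints show up in the terminal.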