Enabling out-of-order batching in a multiprocess DataLoader?

I am working with a dataset where samples take uneven amounts of time to load. Most samples load in about 20 ms, while some take much longer (e.g. 10 seconds). I have a custom Dataset and I’m using a DataLoader to parallelize the loading process.

I’ve noticed that even when I set shuffle=True, the data loader will block waiting for certain workers. In an optimal situation, you would expect all available workers to be processing an “expensive” sample at any given moment but what I observe is that the workers are underutilized. A worker that processed a fast sample may be blocked on a worker that is processing a slow sample.

I can implement my own solution using multiprocessing.Pool.imap_unordered but I wanted to see if there is a canonical PyTorch way to solve this problem.
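To make the ordering difference concrete, here is a minimal sketch using only the standard library. It uses `multiprocessing.pool.ThreadPool` as a thread-based stand-in for a process pool so it runs without a `__main__` guard, and `load` is a hypothetical placeholder for `Dataset.__getitem__` in which sample 0 is the slow one. The sleep durations are assumptions chosen only to make the effect visible.

```python
import time
from multiprocessing.pool import ThreadPool


def load(idx):
    # Hypothetical stand-in for Dataset.__getitem__: sample 0 is slow,
    # every other sample is fast (durations are illustrative assumptions).
    time.sleep(0.5 if idx == 0 else 0.01)
    return idx


def collect(unordered):
    # Return the sample indices in the order the consumer receives them.
    with ThreadPool(4) as pool:
        fn = pool.imap_unordered if unordered else pool.imap
        return list(fn(load, range(8)))


in_order = collect(unordered=False)      # results follow submission order
out_of_order = collect(unordered=True)   # results follow completion order
print('imap          ', in_order)
print('imap_unordered', out_of_order)
```

With `imap`, the consumer cannot receive any fast sample until the slow sample 0 has been returned. With `imap_unordered`, the fast samples arrive first and the slow one arrives whenever it finishes.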

I have some example code that shows the speed difference between the existing torch DataLoader and a custom imap_unordered solution:

import time
import multiprocessing
import numpy as np

from torch.utils.data import DataLoader, Dataset

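class DemoDataset(Dataset):
    def __len__(self):
        return 1000

    def __getitem__(self, idx):
        # Every tenth sample simulates a slow load.
        if idx % 10 == 0:
            print('wait', idx)
            time.sleep(1)  # simulated slow load; the duration is an assumption
            print('done', idx)
        return idx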

class FastLoader(object):
    def __init__(self, ds, num_workers, batch_size, shuffle):
        self.ds = ds
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.shuffle = shuffle

    def _process(self, batch):
        return [self.ds[x] for x in batch]

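    def __iter__(self):
        work = list(range(len(self.ds)))
        if self.shuffle:
            np.random.shuffle(work)
        # Split the (optionally shuffled) indices into batches.
        batches = []
        for i in range(0, len(work), self.batch_size):
            batches.append(work[i:i + self.batch_size])

        # Yield each batch as soon as any worker finishes it (completion order).
        with multiprocessing.Pool(self.num_workers) as p:
            for item in p.imap_unordered(self._process, batches):
                yield item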

def test(cls):
    a = time.time()
    ds = DemoDataset()
    data = cls(ds, num_workers=4, batch_size=4, shuffle=True)

    for item in iter(data):
        pass  # consume every batch; only the total time matters here

    b = time.time()

    return b - a

if __name__=='__main__':
    fast = test(FastLoader)
    base = test(DataLoader)
    print('fast', fast)
    print('base', base)

In this example, the “fast” loader is nearly twice as fast, because most of the time all available workers are kept busy on the slow samples instead of waiting on one another.