How to shuffle multi-worker IterableDataset loading

My data is stored in class-specific files; each file contains all the data for one class.
I’m currently loading it with a custom IterableDataset and a multi-worker DataLoader. Because my dataset is quite large, each worker is responsible for a fixed set of n_files/n_workers files, as I do not want to load every file into memory on every worker.

The problem is that, because each file is class-specific, each worker only produces batches containing the classes it has been assigned (based on its worker ID). Each worker has its own copy of the dataset and collate_fn, so batching happens within the worker.
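To see why the batches end up class-specific, here is a small pure-Python simulation (no PyTorch needed). It assumes each worker interleaves its files the way `CustomDataset.__iter__` does with `zip`/`chain.from_iterable`, collates consecutive samples into a batch inside the worker, and that the loader then hands out whole batches from the workers in turn (round-robin). The helper names and toy class labels are hypothetical.

```python
from itertools import chain
from typing import Iterator, List


def worker_stream(class_files: List[List[str]]) -> Iterator[str]:
    # Interleave samples from this worker's class files (same pattern as CustomDataset).
    return chain.from_iterable(zip(*class_files))


def collate_in_worker(stream: Iterator[str], batch_size: int) -> Iterator[List[str]]:
    # A batch is built from consecutive samples of ONE worker's stream.
    batch: List[str] = []
    for sample in stream:
        batch.append(sample)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def round_robin(workers: List[Iterator[List[str]]]) -> Iterator[List[str]]:
    # Assumed loader behaviour: take one whole batch from each worker in turn.
    active = list(workers)
    while active:
        for it in list(active):
            try:
                yield next(it)
            except StopIteration:
                active.remove(it)


# Hypothetical data: worker 0 owns the files for classes A and B, worker 1 owns C and D.
worker_files = [[["A"] * 4, ["B"] * 4], [["C"] * 4, ["D"] * 4]]
workers = [collate_in_worker(worker_stream(files), batch_size=4) for files in worker_files]
batches = list(round_robin(workers))
print(batches)
# [['A', 'B', 'A', 'B'], ['C', 'D', 'C', 'D'], ['A', 'B', 'A', 'B'], ['C', 'D', 'C', 'D']]
```

Every batch only contains the classes owned by a single worker, and reordering whole batches cannot change that; the samples have to be mixed across workers before or while they are batched.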

How to shuffle an iterable dataset discusses how to shuffle using (which isn’t in the docs?). It applies to any iterator, but applying it to the Dataset only shuffles within each worker, and thus within each batch, while applying it to the DataLoader only shuffles the order in which the batches are yielded. The shuffle flag in the DataLoader also just throws an error about not selecting mode.

Is there some argument to DataLoader that lets me mix and re-collect a buffer of completed batches? Ideally I don’t want to just iterate, collect n_workers batches and shuffle them manually.


Here’s a simple example of what I’m doing:

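```python
from itertools import chain
from typing import List

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class CustomDataset(IterableDataset):
    def __init__(self, files: List[str]):
        self.files = files

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            # Single-process loading: this process reads every file.
            files_chunk = self.files
        else:
            # Each worker reads its own contiguous slice of the files
            # (the last n_files % n_workers files are not assigned to any worker).
            n_workers = worker_info.num_workers
            n_files = len(self.files)
            chunk_size = n_files // n_workers
            chunk_start = chunk_size * worker_info.id
            files_chunk = self.files[chunk_start: chunk_start + chunk_size]
        files_data = [open(f, 'r') for f in files_chunk]
        # Interleave lines across this worker's files; zip stops at the shortest file.
        for line in chain.from_iterable(zip(*files_data)):
            yield line


# The remaining DataLoader arguments were cut off in the original post.
dataloader = DataLoader(dataset, ...)
```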
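One side issue in the chunking above: `chunk_size = n_files // n_workers` means that when the number of files is not divisible by the number of workers, the last `n_files % n_workers` files are never read. A minimal sketch of a contiguous split that spreads the remainder over the first workers instead; `contiguous_shard` is a hypothetical helper which, inside `__iter__`, would be called as `contiguous_shard(self.files, worker_info.id, worker_info.num_workers)`.

```python
from typing import List


def contiguous_shard(files: List[str], worker_id: int, num_workers: int) -> List[str]:
    # Every worker gets `base` files; the first `extra` workers get one more,
    # so the shards cover all files exactly once, in order.
    base, extra = divmod(len(files), num_workers)
    start = worker_id * base + min(worker_id, extra)
    end = start + base + (1 if worker_id < extra else 0)
    return files[start:end]


files = [f"class_{i}.txt" for i in range(10)]  # hypothetical file names
shards = [contiguous_shard(files, w, 3) for w in range(3)]
print([len(s) for s in shards])  # [4, 3, 3]
```

This only fixes which files each worker reads; each worker still sees just its own classes, so it does not address the shuffling problem by itself.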

I’ve managed to find a solution that works, even if it is a bit ugly.

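```python
from torch.utils.data import DataLoader, default_collate
from torch.utils.data.datapipes.iter import combinatorics
# batch_size=1; the lambda strips the extra leading dimension (other args were cut off in the post).
dataloader = DataLoader(dataset, batch_size=1,
                        collate_fn=lambda batch: {k: v[0] for k, v in default_collate(batch).items()})
shuffled = combinatorics.ShufflerIterDataPipe(dataloader, buffer_size=2 * args.batch_size)
# Second loader batches the shuffled samples in the main process (batch_size assumed).
dataloader = DataLoader(shuffled, batch_size=args.batch_size, num_workers=0)
```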

The lambda removes the extra leading dimension that gets added when the first DataLoader runs with batch_size=1. ShufflerIterDataPipe then shuffles the individual samples before they are batched by the second DataLoader. Because that second loader runs with num_workers=0, I can also run GPU operations in it, though I’ve found this to conflict with pin_memory=True.
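The same un-batch, shuffle, re-batch pipeline can be sketched without PyTorch to show how the pieces fit together. `shuffle_buffer` below is a plain buffer-based shuffle (keep `buffer_size` items, emit a randomly chosen one and put the next incoming item in its slot), in the spirit of what the `buffer_size` of ShufflerIterDataPipe controls; it is an illustrative stand-in, not that class's implementation. The function names and input batches are hypothetical.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shuffle_buffer(items: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    # Fill the buffer, then emit a random buffered item and put the new item in its slot.
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def rebatch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    # Role of the second DataLoader: group the shuffled samples into batches.
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


# Hypothetical class-specific batches, as produced by two workers in turn.
worker_batches = [["A", "B", "A", "B"], ["C", "D", "C", "D"]] * 3
batch_size = 4
# Un-batch (the batch_size=1 first loader), shuffle, then re-batch.
samples = (x for b in worker_batches for x in b)
mixed = list(rebatch(shuffle_buffer(samples, 2 * batch_size, random.Random(0)), batch_size))
```

In general a larger buffer lets samples from more worker batches end up together, at the cost of holding more samples in memory; the post uses `2 * batch_size`.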

I’ve found an interesting solution to shuffle across batches by subclassing the dataloader:

```python
import torch
from torch.utils.data import DataLoader


class BufferShufflingDataLoader(DataLoader):
    """Subclasses the PyTorch DataLoader class to shuffle the dataset
    across workers with a buffer size of num_workers."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.multi_buffer_size = kwargs['num_workers']  # one buffered batch per worker
        self.multi_buffer = []

    def __iter__(self):
        super_iter = super().__iter__()
        for i in range(self.multi_buffer_size):
            ...  # loop body was lost from the original post
        return self.multi_buffer_iterator(super_iter)

    def multi_buffer_iterator(self, super_iter):
        for i in super_iter:
            # Stack the buffered batches, take the first sample of each,
            # then drop the oldest batch from the buffer.
            x = torch.stack(self.multi_buffer)[:, 0]
            _ = self.multi_buffer.pop(0)
            yield x
```

It is not super efficient in this form, but it could easily be sped up by replacing the multi_buffer list with a fixed-size tensor.
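For comparison, a minimal pure-Python sketch of a cross-worker buffer that mixes samples rather than taking one sample per buffered batch: it collects `num_workers` consecutive batches (one per worker, assuming round-robin ordering), pools and shuffles their samples, and splits them back into batches of the original sizes, so every sample is emitted exactly once. `mix_worker_batches` is a hypothetical name; with tensor batches, the pooling and shuffling step would typically be `torch.cat` followed by indexing with `torch.randperm`.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def _shuffle_and_split(buffer: List[List[T]], rng: random.Random) -> Iterator[List[T]]:
    # Pool all buffered samples, shuffle them, and cut them back to the original batch sizes.
    sizes = [len(b) for b in buffer]
    pool = [x for b in buffer for x in b]
    rng.shuffle(pool)
    start = 0
    for size in sizes:
        yield pool[start:start + size]
        start += size


def mix_worker_batches(batches: Iterable[List[T]], num_workers: int,
                       rng: random.Random) -> Iterator[List[T]]:
    # Buffer one batch per worker, then emit the shuffled group.
    buffer: List[List[T]] = []
    for batch in batches:
        buffer.append(batch)
        if len(buffer) == num_workers:
            yield from _shuffle_and_split(buffer, rng)
            buffer = []
    if buffer:  # leftover batches at the end of the epoch
        yield from _shuffle_and_split(buffer, rng)


worker_batches = [["A", "B", "A", "B"], ["C", "D", "C", "D"]] * 2  # hypothetical
mixed = list(mix_worker_batches(worker_batches, num_workers=2, rng=random.Random(0)))
```

Mixing is limited to each group of `num_workers` batches, which is enough to break up the per-worker class grouping but is not a full shuffle of the dataset.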