My data is stored in class-specific files; each file contains all the data for one class.
I'm currently loading it with a custom IterableDataset and a multi-worker DataLoader. As my dataset is quite large, each worker is responsible for a fixed set of
n_files/n_workers files, since I do not want to load every file into memory on every worker.
The problem is that, since each file is class-specific, each worker only produces batches containing the classes it has been assigned (based on worker ID). Each worker has its own copy of
collate_fn, so batching happens within the worker.
How to shuffle an iterable dataset discusses shuffling using
torch.utils.data.datapipes.iter.combinatorics.ShuffleIterDataPipe (which doesn't seem to be in the docs?). It can wrap any iterator, but applying it to the
Dataset shuffles within each worker, and thus within each batch, while applying it to the
DataLoader only shuffles the order in which the batches are yielded. The
shuffle flag on the
DataLoader also just throws me an error about not selecting a mode.
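For context, my understanding is that a per-worker shuffle buffer (roughly what ShuffleIterDataPipe does, as far as I can tell; this is a minimal sketch, not its actual implementation) looks like the following, which makes it clear why it can never mix samples across workers: each worker process gets its own independent buffer.

```python
import random
from itertools import islice

def shuffle_buffer(iterable, buffer_size, rng=None):
    # Fill a fixed-size buffer, then for each incoming item yield a random
    # buffered element and replace it with the new item. Because every
    # DataLoader worker runs its own copy of the dataset iterator, each
    # worker has its own buffer, so samples are only shuffled within a worker.
    rng = rng or random.Random()
    it = iter(iterable)
    buffer = list(islice(it, buffer_size))
    for item in it:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Drain whatever is left in the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer
```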
Is there some argument to
DataLoader that lets me mix and re-collate some buffer of completed batches? Ideally I don't want to just iterate, collect some
n_worker batches, and shuffle them manually.
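For reference, the manual workaround I'd like to avoid looks roughly like this (a hypothetical helper, sketched with plain lists as batches; real tensor batches would additionally need to be un-collated and re-collated, which is exactly the awkward part):

```python
import random

def reshuffle_across_batches(batch_iter, n_batches, seed=0):
    # Collect n_batches already-collated batches, flatten their samples,
    # shuffle across them, and re-chunk into batches of the original size.
    # With tensor batches you'd un-collate (e.g. torch.unbind) and
    # re-collate (e.g. torch.stack) instead of using plain lists.
    rng = random.Random(seed)
    buffer = []
    batch_size = None
    for batch in batch_iter:
        if batch_size is None:
            batch_size = len(batch)
        buffer.extend(batch)
        if len(buffer) >= n_batches * batch_size:
            rng.shuffle(buffer)
            while len(buffer) >= batch_size:
                yield buffer[:batch_size]
                buffer = buffer[batch_size:]
    # Emit any leftover samples as a final (possibly smaller) batch.
    if buffer:
        rng.shuffle(buffer)
        yield buffer
```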
Here's a simple example of what I'm doing:

```python
from itertools import chain
from typing import List

import torch
from torch.utils.data import DataLoader, IterableDataset


class CustomDataset(IterableDataset):
    def __init__(self, files: List[str]):
        self.files = files

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: read every file.
            files_chunk = self.files
        else:
            # Assign each worker a fixed, contiguous chunk of the files.
            n_workers = worker_info.num_workers
            n_files = len(self.files)
            chunk_size = n_files // n_workers
            chunk_start = chunk_size * worker_info.id
            files_chunk = self.files[chunk_start:chunk_start + chunk_size]
        files_data = [open(f, 'r') for f in files_chunk]
        # Interleave one line from each of this worker's files at a time.
        for line in chain.from_iterable(zip(*files_data)):
            yield line


dataloader = DataLoader(dataset, batch_size=args.batch_size,
                        num_workers=n_cpus, pin_memory=True)
```