My data is stored in class-specific files; each file contains all the data for one class.
I’m currently loading it with a custom IterableDataset and a multi-worker DataLoader. As my dataset is quite large, each worker is responsible for a fixed set of n_files/n_workers files, since I do not want to load every file into memory on every worker.
The problem is that, because each file is class-specific, each worker only produces batches containing the classes it has been assigned (based on worker ID). Each worker has its own copy of the dataset and collate_fn, so it batches within the worker.
How to shuffle an iterable dataset discusses shuffling with torch.utils.data.datapipes.iter.combinatorics.ShuffleIterDataPipe (which doesn’t seem to be in the docs?). It can wrap any iterator, but applying it to the Dataset only shuffles within each worker, and thus within each batch, while applying it to the DataLoader only shuffles the order in which batches are yielded. Setting the shuffle flag on the DataLoader also just throws an error about the mode not being set.
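To make concrete what I mean by “shuffles within each worker”: as far as I understand, a pipe of this sort just keeps a buffer of samples and yields them in random order, and since each worker iterates its own copy, the buffer only ever holds that worker’s classes. A rough stand-in for the idea (not the actual ShuffleIterDataPipe implementation or API) would be something like:

import random
from torch.utils.data import IterableDataset

class BufferShuffle(IterableDataset):
    # illustrative shuffle buffer, not the real datapipe
    def __init__(self, dataset, buffer_size=1000):
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self):
        buffer = []
        # each worker runs this on its *own* copy of the wrapped dataset,
        # so the buffer only ever contains samples from that worker's files
        for sample in self.dataset:
            if len(buffer) < self.buffer_size:
                buffer.append(sample)
            else:
                idx = random.randrange(self.buffer_size)
                yield buffer[idx]
                buffer[idx] = sample
        random.shuffle(buffer)
        yield from buffer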
Is there some argument to DataLoader that would let me mix and re-collate some buffer of completed batches? Ideally I don’t want to just iterate, collect some n_workers batches, and shuffle them manually.
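For reference, the manual version of that mix-and-re-collate idea, which I’d rather not write myself, would look roughly like this (sketch only; it assumes each batch is a single tensor, whereas my real batches come out of collate_fn):

import torch

def mix_batches(dataloader, n_buffer_batches, batch_size):
    # hypothetical helper: buffer already-collated batches from all workers,
    # shuffle the pooled samples, and cut them into fresh batches
    buffer = []
    for batch in dataloader:
        buffer.append(batch)
        if len(buffer) == n_buffer_batches:
            pooled = torch.cat(buffer, dim=0)
            pooled = pooled[torch.randperm(pooled.size(0))]
            yield from pooled.split(batch_size)
            buffer = []
    if buffer:
        pooled = torch.cat(buffer, dim=0)
        pooled = pooled[torch.randperm(pooled.size(0))]
        yield from pooled.split(batch_size)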
Thanks!
Here’s a simple example of what I’m doing:

import torch
from itertools import chain
from typing import List
from torch.utils.data import DataLoader, IterableDataset


class CustomDataset(IterableDataset):
    def __init__(self, files: List[str]):
        self.files = files

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # single-process loading: read every file
            files_chunk = self.files
        else:
            # each worker reads only its own contiguous chunk of files
            n_workers = worker_info.num_workers
            n_files = len(self.files)
            chunk_size = n_files // n_workers
            chunk_start = chunk_size * worker_info.id
            files_chunk = self.files[chunk_start: chunk_start + chunk_size]
        files_data = [open(f, 'r') for f in files_chunk]
        # round-robin one line at a time from each of this worker's files
        for line in chain.from_iterable(zip(*files_data)):
            yield line


dataset = CustomDataset(files)  # files: the full list of class-specific file paths
dataloader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        num_workers=n_cpus,
                        pin_memory=True)
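Just to illustrate the symptom with made-up file contents: a worker that was assigned, say, the class-A and class-B files round-robins between them, so every batch cut from its stream only ever contains classes A and B:

from itertools import chain

# hypothetical contents of the two files assigned to one worker
files_data = [["a0", "a1", "a2"],   # class A file
              ["b0", "b1", "b2"]]   # class B file
print(list(chain.from_iterable(zip(*files_data))))
# ['a0', 'b0', 'a1', 'b1', 'a2', 'b2'] -> any batch from this worker has only classes A and B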