I have a torch.utils.data.IterableDataset that loops over files and generates batches. Neither the number of files nor the number of batches in each file is known ahead of time, hence the need for IterableDataset.
At the beginning of dataset.__iter__, I set self.worker_id and self.num_workers from torch.utils.data.get_worker_info(), and use this information to split the files between workers so that they don't all just generate identical batches.
Each worker's dataset instance stores a file_idx and sample_idx so that it can be resumed if the job dies.
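For reference, this is roughly the shape of the dataset. get_state(), file_idx, and sample_idx are as described above; FileStreamDataset, _read_samples, and the line-per-sample parsing are simplified placeholders, and the resume/skip logic is omitted:

import torch
from torch.utils.data import IterableDataset, get_worker_info

class FileStreamDataset(IterableDataset):
    def __init__(self, files):
        self.files = files
        self.file_idx = 0      # resumable position
        self.sample_idx = 0

    def get_state(self):
        # Minimal state needed to resume this worker's shard after a crash.
        return {'file_idx': self.file_idx, 'sample_idx': self.sample_idx}

    def __iter__(self):
        info = get_worker_info()
        self.worker_id = info.id if info is not None else 0
        self.num_workers = info.num_workers if info is not None else 1
        # Shard files across workers so they don't generate identical batches.
        for self.file_idx, path in enumerate(self.files):
            if self.file_idx % self.num_workers != self.worker_id:
                continue
            for self.sample_idx, sample in enumerate(self._read_samples(path)):
                yield sample

    def _read_samples(self, path):
        # Placeholder for the real parsing; the number of samples per file
        # is unknown ahead of time.
        with open(path) as f:
            for line in f:
                yield torch.tensor([float(x) for x in line.split()])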
I want the dataloader to collect dictionaries of {worker_id: dataset_instance.get_state()}, so I started passing worker states through Dataset.__iter__ and collate_fn, but there is no guarantee that I ever see the state of every worker. In my test, for example, every element of every batch appears to be generated by worker 0, so the dataloader only persists the state of worker 0.
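Concretely, the state rides along with each sample and gets merged back out in the collate_fn, roughly like this (collate_with_state is a hypothetical name; the dataset's __iter__ yields (sample, {worker_id: get_state()}) pairs instead of bare samples). The StatefulDataLoader below then accumulates these per-batch states:

from torch.utils.data import default_collate

def collate_with_state(items):
    # Each item is a (sample, {worker_id: state}) pair. Since each batch is
    # produced entirely by one worker, there is normally one worker_id per batch.
    samples, states = zip(*items)
    batch = default_collate(list(samples))
    worker_state = {}
    for s in states:
        worker_state.update(s)
    return batch, worker_state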
class StatefulDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._batch_num = 0
        self.global_worker_states = {}

    def __iter__(self):
        iterator = super().__iter__()
        return self._wrap_iterator(iterator)

    def _wrap_iterator(self, iterator):
        # collate_fn returns (batch, worker_state) tuples; merge each
        # worker_state into the running global dict before yielding.
        for batch, worker_state in iterator:
            self.global_worker_states.update(worker_state)
            print(f'Batch: {self._batch_num}: GlobalWorkerStates: {self.global_worker_states}')
            yield batch, self.global_worker_states
            self._batch_num += 1
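It is driven like a normal DataLoader; FileStreamDataset, collate_with_state, files, and train_step below are placeholder names from the sketches above:

loader = StatefulDataLoader(
    FileStreamDataset(files),
    batch_size=32,
    num_workers=4,
    collate_fn=collate_with_state,
)
for batch, worker_states in loader:
    train_step(batch)                         # placeholder training step
    latest_checkpoint_state = worker_states   # in my test this only ever contains worker 0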
Is there a mechanism to:

a) guarantee that batches are composed of entries from every worker,
b) call dataset.get_state() on every worker from the dataloader, or
c) more generally, save the state of every worker instance?
My high-level goal in using extra workers is to prefetch batches and hide I/O latency from the training process.