Dataloader: access state from each worker

I have a torch.utils.data.IterableDataset that loops over files and generates batches. Neither the number of files nor the number of batches in each file is known ahead of time, hence the need for IterableDataset.

At the beginning of dataset.__iter__, I set self.worker_id and self.num_workers from torch.utils.data.get_worker_info(), and use this information to split the files between workers so that they don't all just generate identical batches.

Each worker/dataset instance stores a file_idx and sample_idx so that they can be resumed if the job dies.
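
Roughly, the dataset looks like this (a minimal sketch; ResumableFileDataset and _read_samples are placeholder names):

    import torch
    from torch.utils.data import IterableDataset, get_worker_info

    class ResumableFileDataset(IterableDataset):
        """Sketch: shard files across workers and track resume state."""

        def __init__(self, files):
            self.files = files
            self.file_idx = 0    # index of the file currently being read
            self.sample_idx = 0  # index of the next sample within that file

        def get_state(self):
            # Enough information to resume this worker's stream after a restart.
            return {"file_idx": self.file_idx, "sample_idx": self.sample_idx}

        def __iter__(self):
            info = get_worker_info()
            self.worker_id = info.id if info is not None else 0
            self.num_workers = info.num_workers if info is not None else 1

            # Round-robin split of files so workers don't produce identical batches.
            for self.file_idx, path in enumerate(self.files):
                if self.file_idx % self.num_workers != self.worker_id:
                    continue
                for self.sample_idx, sample in enumerate(self._read_samples(path)):
                    yield sample

        def _read_samples(self, path):
            raise NotImplementedError  # file-format specific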

I want the dataloader to collect dictionaries of

{worker_id: dataset_instance.get_state()}

so I started passing worker states through Dataset.__iter__ and collate_fn (sketched below)

but there is no guarantee that I have ever seen the state of every worker.
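
The plumbing looks roughly like this (assuming __iter__ is changed to yield (sample, worker_id, state) tuples):

    import torch

    def stateful_collate(samples):
        # Each sample is assumed to arrive as (data, worker_id, state); the most
        # recent state seen from a given worker wins for this batch.
        data = [s for s, _, _ in samples]
        worker_states = {worker_id: state for _, worker_id, state in samples}
        return torch.utils.data.default_collate(data), worker_states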

In my test, for example, every element of every batch appears to be generated by worker 0, so the dataloader only ever persists the state of worker 0.

import torch

class StatefulDataLoader(torch.utils.data.DataLoader):
    """Accumulates the latest state seen from each worker.

    Assumes the collate_fn returns (batch, {worker_id: state}) tuples.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._batch_num = 0
        self.global_worker_states = {}

    def __iter__(self):
        iterator = super().__iter__()
        return self._wrap_iterator(iterator)

    def _wrap_iterator(self, iterator):
        for batch, worker_state in iterator:
            # Merge this batch's {worker_id: state} into the running snapshot.
            self.global_worker_states.update(worker_state)
            print(f'Batch: {self._batch_num}: GlobalWorkerStates: {self.global_worker_states}')
            yield batch, self.global_worker_states
            self._batch_num += 1

Is there a mechanism to:
a) guarantee that batches are composed of entries from every worker,
b) call dataset.get_state() on every worker from the dataloader, or
c) more generally, save the state of every worker instance?

My high-level goal in using extra workers is to prefetch batches and hide I/O latency from the training process.

Hi Sam, please have a look at StatefulDataLoader in the torchdata repo. It's a drop-in replacement for torch.utils.data.DataLoader and handles what you're requesting 🙂

We are currently preparing the next release, torchdata 0.8.0, to follow PyTorch 2.4.0, and it will include the new StatefulDataLoader. Until then, you can install it from the nightly builds or from a 0.8.0rcX release candidate if one is available.
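
Roughly, the usage looks like this (a sketch assuming the torchdata.stateful_dataloader module from 0.8.0; train_step and the checkpoint cadence are placeholders, so please check the torchdata docs for the exact API):

    import torch
    from torchdata.stateful_dataloader import StatefulDataLoader

    loader = StatefulDataLoader(dataset, batch_size=32, num_workers=4)

    for step, batch in enumerate(loader):
        train_step(batch)
        if step % 1000 == 0:
            # state_dict() snapshots per-worker progress, including your
            # dataset's state if it implements state_dict()/load_state_dict().
            torch.save(loader.state_dict(), "loader_ckpt.pt")

    # After a restart, rebuild the loader and restore where each worker left off:
    loader = StatefulDataLoader(dataset, batch_size=32, num_workers=4)
    loader.load_state_dict(torch.load("loader_ckpt.pt"))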