When does a Pytorch Dataset worker pull new data?

Does a torch.utils.data.DataLoader worker initiate fetching when the iterator is created (__iter__) or when it is advanced (__next__)? Once the DataLoader collates a batch, does the worker automatically start pulling the next one, or does it wait idly until the iterator is called again?

TLDR: At what point does it fetch? What triggers the fetch?

The workers start loading the data once the iterator is created and preload prefetch_factor * num_workers batches. Afterwards, each next call consumes a batch and lets the corresponding worker load the next one.
Here is a small example showing this behavior:

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(512, 1)

    def __getitem__(self, idx):
        # get_worker_info() returns None in the main process,
        # so this example assumes num_workers > 0
        worker_info = torch.utils.data.get_worker_info()
        print("loading idx {} from worker {}".format(idx, worker_info.id))
        return self.data[idx]

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
loader = DataLoader(dataset, num_workers=8, batch_size=2)
loader_iter = iter(loader) # preloading starts here
# with the default prefetch_factor of 2, 2*num_workers=16 batches will be preloaded
# the max index printed by __getitem__ is thus 31 (16*batch_size=32 samples loaded)

data = next(loader_iter) # consumes one prefetched batch and lets a single worker load the next one to refill the queue
# batch_size=2 new samples should be loaded
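One way to watch the prefetching happen (a minimal sketch, assuming the workers' print output reaches your console) is to pause right after creating the iterator:

import time

loader_iter = iter(loader)  # workers spin up and start prefetching immediately
time.sleep(2)               # give the workers time to fill their queues
# by now all "loading idx ..." prints for the first 16 batches should be visible
data = next(loader_iter)    # returns almost instantly from the prefetched queue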

One more question while you’re here…

Suppose one of the workers modifies their own data instance:

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(512, 1)

    def __getitem__(self, idx):
        # assumes num_workers > 0; only worker 2 mutates its copy of the data
        worker_info = torch.utils.data.get_worker_info()
        if worker_info.id == 2:
            self.data[7, :] = 5
        return self.data[idx]

    def __len__(self):
        return len(self.data)

Now from outside the worker, is there a way for me to get that worker’s version of the data? As in:

loader = DataLoader(dataset, num_workers=8, batch_size=2)
data_from_worker = loader.get_worker(id=2).data

Is this possible?

I’m not aware of a mechanism to access the workers directly besides the information exposed via get_worker_info() inside the Dataset. Note that with num_workers > 0 each worker runs in its own process and operates on its own copy of the dataset, so these in-place modifications won’t be visible in the main process.
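A minimal sketch to confirm this isolation (using the modifying MyDataset from the previous post; the write to self.data[7, :] stays in worker 2's copy):

dataset = MyDataset()
before = dataset.data[7].clone()

loader = DataLoader(dataset, num_workers=8, batch_size=2)
for batch in loader:
    pass  # worker 2 sets self.data[7, :] = 5, but only in its own process

# the main process' copy of the data is unchanged
print(torch.equal(dataset.data[7], before))  # True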

Thanks for the answer above. I wanted to know where exactly in dataloader.py “the next call consumes the current batch and loads the next batch”?

As far as I understand, the __next__ function calls _next_data.

_next_data calls _process_data,

and finally, inside _process_data, _try_put_index is called, which fetches the next indices from the sampler, finds the next active worker, and assigns those indices to that worker's queue.
[screenshot of _try_put_index in dataloader.py]

On the other hand, for consuming the current index, worker.py performs an index_queue.get() that retrieves the next index from the queue and loads the corresponding batch via fetcher.fetch.

Can you please confirm whether I understand this properly?

The description of the call chain sounds reasonable based on the posted screenshots of the code, but to be sure I would recommend verifying it by stepping through with a debugger or e.g. adding debug print statements to check the actual execution order.
Also related to this thread.
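For intuition, here is a heavily simplified, self-contained sketch of that producer/consumer mechanism. This is not the actual PyTorch source; the real _worker_loop, _next_data, _process_data, and _try_put_index do much more bookkeeping, but the queue structure is the same:

import multiprocessing as mp

def worker_loop(dataset, index_queue, data_queue):
    # simplified stand-in for worker.py's _worker_loop: block on
    # index_queue.get(), fetch the samples, send the batch back
    while True:
        batch_indices = index_queue.get()
        if batch_indices is None:  # shutdown signal
            break
        data_queue.put([dataset[i] for i in batch_indices])

if __name__ == "__main__":
    dataset = list(range(8))
    index_queue, data_queue = mp.Queue(), mp.Queue()
    sampler = iter([[0, 1], [2, 3], [4, 5], [6, 7]])

    worker = mp.Process(target=worker_loop, args=(dataset, index_queue, data_queue))
    worker.start()

    for _ in range(2):                       # prime the pipeline (~ prefetch_factor)
        index_queue.put(next(sampler))

    for _ in range(4):
        batch = data_queue.get()             # ~ _next_data: wait for a finished batch
        try:
            index_queue.put(next(sampler))   # ~ _try_put_index: refill with new indices
        except StopIteration:
            pass                             # sampler exhausted, nothing left to enqueue
        print(batch)

    index_queue.put(None)                    # shut the worker down
    worker.join()

Running it shows the main process priming the pipeline and then, for each consumed batch, immediately enqueueing new indices, which mirrors the behavior described above.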

Does this mean that loader_iter = iter(loader) only needs to be called once in a training loop, because each time a batch is consumed via next(loader_iter), a new batch will be preloaded to maintain the number of prefetched batches?

Yes, that’s right, but note that once the iterator is exhausted you will need to recreate it; unless persistent_workers=True is set, this also shuts down and re-spawns the worker processes.
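As a small usage sketch, the common pattern is to let a for loop recreate the iterator each epoch; with persistent_workers=True the same worker processes are reused across epochs:

loader = DataLoader(dataset, num_workers=8, batch_size=2, persistent_workers=True)

for epoch in range(3):
    # the for loop implicitly calls iter(loader) each epoch;
    # persistent_workers=True keeps the worker processes alive between them
    for batch in loader:
        pass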
