When does a Pytorch Dataset worker pull new data?

The workers start loading data once the iterator is created and preload prefetch_factor * num_workers batches. Afterwards, each next() call consumes one batch and lets the corresponding worker load the next batch to refill the queue.
Here is a small example showing this behavior:

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(512, 1)
        
    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        print("loading idx {} from worker {}".format(
            idx, worker_info.id))
        x = self.data[idx]
        return x
    
    def __len__(self):
        return len(self.data)


dataset = MyDataset()
loader = DataLoader(dataset, num_workers=8, batch_size=2)
loader_iter = iter(loader) # preloading starts here
# with the default prefetch_factor of 2, 2*num_workers=16 batches will be preloaded
# the max index printed by __getitem__ is thus 31 (16*batch_size=32 samples loaded)

data = next(loader_iter) # consumes one batch; a single worker then loads the next batch to refill the queue
# batch_size=2 new samples should be loaded (indices 32 and 33 printed)
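The preload arithmetic above can be sketched as a small helper (preloaded_samples is a hypothetical name, not a PyTorch API; it just mirrors the prefetch_factor * num_workers rule from the DataLoader docs):

```python
def preloaded_samples(num_workers, batch_size, prefetch_factor=2):
    # Each worker preloads prefetch_factor batches,
    # so the total preloaded sample count is:
    return prefetch_factor * num_workers * batch_size

# Matches the example above: 2 * 8 * 2 = 32 samples,
# so the highest index printed by __getitem__ is 31.
print(preloaded_samples(num_workers=8, batch_size=2))  # 32
```

Note that this applies only when num_workers > 0; with num_workers=0, data is loaded lazily in the main process and prefetch_factor is ignored.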