When does a PyTorch Dataset worker pull new data?

Does a torch.utils.data.Dataset worker start fetching when the iterator is created (__iter__) or when next is called on it (__next__)? Once the DataLoader collates a batch, does the worker automatically start pulling the next one, or does it wait idly until the iterator is called again?

TL;DR: At what point does it fetch, and what triggers the fetch?

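The workers start loading data as soon as the iterator is created (iter(loader), i.e. __iter__) and preload prefetch_factor * num_workers batches. After that, each next() call (__next__) consumes one batch and hands one new batch index to a worker, which then loads that batch. Workers do not load anything beyond this prefetch budget on their own; they wait until the main process requests more.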
Here is a small example showing this behavior:

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(512, 1)

    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        print("loading idx {} from worker {}".format(
            idx, worker_info.id))
        x = self.data[idx]
        return x

    def __len__(self):
        return len(self.data)


if __name__ == "__main__":  # needed where workers are spawned (e.g. Windows, macOS)
    dataset = MyDataset()
    loader = DataLoader(dataset, num_workers=8, batch_size=2)
    loader_iter = iter(loader)  # workers start and preloading begins here
    # with the default prefetch_factor of 2, 2*num_workers = 16 batches are preloaded,
    # so the max index printed by __getitem__ is 31 (16 * batch_size = 32 samples loaded)

    data = next(loader_iter)  # consumes one batch and requests one more from a single worker to refill the queue
    # so batch_size = 2 new samples (indices 32 and 33) should be loaded
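To make the scheduling described above explicit without spawning any processes, here is a toy pure-Python model of it. This is a simplified sketch, not PyTorch's implementation: the class name SimulatedLoaderIter is hypothetical, and it assumes batch indices are handed to workers round-robin with prefetch_factor * num_workers batches kept in flight. It ignores out-of-order arrival, worker shutdown, and iterable-style datasets.

```python
from collections import deque


class SimulatedLoaderIter:
    """Toy model of a multi-worker DataLoader iterator (hypothetical, NOT PyTorch code).

    Assumption: batch indices go to workers round-robin, and
    prefetch_factor * num_workers batches are kept in flight.
    """

    def __init__(self, num_batches, num_workers, prefetch_factor=2):
        self.num_batches = num_batches
        self.num_workers = num_workers
        self.next_batch = 0        # next batch index to hand out
        self.in_flight = deque()   # (batch_idx, worker_id) not yet consumed
        self.dispatched = []       # log of every dispatch, in order
        # Creating the iterator fills the pipeline immediately.
        for _ in range(prefetch_factor * num_workers):
            self._dispatch()

    def _dispatch(self):
        if self.next_batch >= self.num_batches:
            return  # sampler exhausted: nothing left to request
        worker_id = self.next_batch % self.num_workers  # round-robin assignment
        self.in_flight.append((self.next_batch, worker_id))
        self.dispatched.append((self.next_batch, worker_id))
        self.next_batch += 1

    def __iter__(self):
        return self

    def __next__(self):
        if not self.in_flight:
            raise StopIteration
        batch_idx, worker_id = self.in_flight.popleft()  # consume the oldest batch
        self._dispatch()  # then top the pipeline back up by exactly one batch
        return batch_idx, worker_id


if __name__ == "__main__":
    it = SimulatedLoaderIter(num_batches=256, num_workers=8)
    print(len(it.dispatched))  # 16 batches requested on creation
    print(next(it))            # (0, 0)
    print(it.dispatched[-1])   # (16, 0): one new batch, sent to worker 0
```

With 8 workers and the default prefetch factor, creating the iterator requests 16 batches, and each next() requests exactly one more, which matches the index range printed by the real example.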

One more question while you’re here…

Suppose one of the workers modifies its own copy of the data:

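import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(512, 1)

    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        # worker_info is None when loading in the main process (num_workers=0)
        if worker_info is not None and worker_info.id == 2:
            self.data[7, :] = 5  # modifies the dataset object inside worker 2's process
        x = self.data[idx]
        return x

    def __len__(self):
        return len(self.data)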

Now from outside the worker, is there a way for me to get that worker’s version of the data? As in:

loader = DataLoader(dataset, num_workers=8, batch_size=2)
data_from_worker = loader.get_worker(id=2).data

Is this possible?

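I’m not aware of a mechanism to access the workers directly besides their information in the Dataset. Each worker runs in a separate process with its own copy of the Dataset object, the DataLoader does not expose a public handle to those processes, and torch.utils.data.get_worker_info() only returns worker information when called inside a worker.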
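The underlying reason is process isolation. The standard-library sketch below illustrates it with a hypothetical ToyDataset holding a plain Python list, assuming a POSIX system where the fork start method is available. A write made in one child process is not visible in the parent's object, and the parent only learns about it if the child sends the value back explicitly. In a real DataLoader, one option along the same lines is to return the state you care about from __getitem__ alongside the sample. Whether writes to tensor storage become visible across processes may depend on the start method and on whether the storage lives in shared memory, so it's best not to rely on either behavior.

```python
import multiprocessing as mp


class ToyDataset:
    """Hypothetical stand-in for a Dataset holding plain Python data."""

    def __init__(self):
        self.data = [0.0] * 16


def worker_loop(dataset, worker_id, out_queue):
    # Runs in a child process, operating on the copy of `dataset` it inherited.
    if worker_id == 2:
        dataset.data[7] = 5.0
    # The parent only learns about this copy's state if the worker sends it back.
    out_queue.put((worker_id, dataset.data[7]))


def run(num_workers=4):
    ctx = mp.get_context("fork")  # assumption: POSIX system with fork available
    dataset = ToyDataset()
    queue = ctx.Queue()
    procs = [ctx.Process(target=worker_loop, args=(dataset, i, queue))
             for i in range(num_workers)]
    for p in procs:
        p.start()
    # Drain the queue before joining so no child blocks on a full pipe.
    reported = dict(queue.get() for _ in range(num_workers))
    for p in procs:
        p.join()
    return dataset.data[7], reported


if __name__ == "__main__":
    parent_value, reported = run()
    print(parent_value)  # 0.0: the parent's object is untouched
    print(reported[2])   # 5.0: worker 2's copy, sent back explicitly
```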