Does a torch.utils.data.Dataset worker initiate fetching when its iterator is created? In __iter__? In __next__? Once the DataLoader collates a batch, does the worker automatically start pulling the next one, or does it wait idly until the iterator is called again?
TL;DR: At what point does it fetch, and what triggers the fetch?
The workers start loading data once the iterator is created and preload prefetch_factor * num_workers batches. Afterwards, each next() call consumes a batch and lets the corresponding worker load the next one.
Here is a small example showing this behavior:
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(512, 1)

    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        print("loading idx {} from worker {}".format(idx, worker_info.id))
        x = self.data[idx]
        return x

    def __len__(self):
        return len(self.data)
dataset = MyDataset()
loader = DataLoader(dataset, num_workers=8, batch_size=2)

loader_iter = iter(loader)  # preloading starts here
# with the default prefetch_factor of 2, 2 * num_workers = 16 batches are preloaded
# the max index printed by __getitem__ is thus 31 (16 * batch_size = 32 samples loaded)

data = next(loader_iter)  # consumes a batch and lets a single worker preload the next one to refill the queue
# batch_size = 2 new samples should be loaded
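For contrast, a quick way to see what triggers the fetch is to set num_workers=0: in that case there is no prefetching at all, nothing is loaded when the iterator is created, and each next() call fetches its batch synchronously in the main process. A minimal sketch (the TrackedDataset name and the loaded list are just illustrative):

```python
import torch
from torch.utils.data import Dataset, DataLoader

loaded = []  # records which indices have been fetched

class TrackedDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        loaded.append(idx)  # runs in the main process when num_workers=0
        return torch.tensor([float(idx)])

loader = DataLoader(TrackedDataset(), num_workers=0, batch_size=2)

it = iter(loader)      # no prefetching with num_workers=0
assert loaded == []    # nothing has been fetched yet

batch = next(it)       # the fetch happens inside next()
assert loaded == [0, 1]
```

This makes the multiprocessing case above easier to interpret: the preloading at iter() time is specific to num_workers > 0.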
One more question while you’re here…
Suppose one of the workers modifies its own copy of the dataset:
class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(512, 1)

    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info.id == 2:
            self.data[7, :] = 5
        x = self.data[idx]
        return x

    def __len__(self):
        return len(self.data)
Now from outside the worker, is there a way for me to get that worker’s version of the data? As in:
loader = DataLoader(dataset, num_workers=8, batch_size=2)
data_from_worker = loader.get_worker(id=2).data
Is this possible?
I’m not aware of a mechanism to access the workers directly besides the information exposed inside the Dataset via torch.utils.data.get_worker_info(). Each worker process gets its own copy of the dataset, so mutations made in a worker are not reflected in the parent process’s dataset object.
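That said, if the goal is only to make a worker’s writes visible to the parent process, one workaround is to move the tensor into shared memory with share_memory_() before the workers are created. This is a sketch, not an endorsed API: it assumes the default fork start method (Linux), under which forked workers inherit the shared-memory tensor and their in-place writes land in memory the parent can read. The SharedDataset name and the specific indices are illustrative:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SharedDataset(Dataset):
    def __init__(self):
        # share_memory_() moves the tensor into shared memory, so in-place
        # writes from forked worker processes are visible to the parent
        self.data = torch.zeros(8, 1).share_memory_()

    def __getitem__(self, idx):
        info = torch.utils.data.get_worker_info()
        if info is not None and info.id == 0:
            self.data[0, 0] = 5.0  # worker 0 mutates the shared tensor in place
        return self.data[idx]

    def __len__(self):
        return len(self.data)

dataset = SharedDataset()
loader = DataLoader(dataset, num_workers=2, batch_size=2)

for _ in loader:  # drain the loader so worker 0 runs at least once
    pass

print(dataset.data[0, 0].item())  # the parent now sees the worker's write
```

Note this only works because the write is in place on a shared-memory tensor; rebinding self.data to a new tensor inside a worker would again be invisible to the parent. With the spawn start method the dataset is pickled to each worker and this trick does not apply.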