If I have multiple workers, would each worker collect consecutive indices?
For example, if I have num_workers set to 2, prefetch_factor set to 2, and batch_size set to 10, would one worker collect the first 2 batches (idx 0 to 6) and the second worker the following 2 batches (idx 6 to 12)?
Each worker reads all the samples unless you specify otherwise in worker_init_fn.
See this note (from the get_worker_info() docs):

"When used in a worker_init_fn passed over to DataLoader, this method can be useful to set up each worker process differently, for instance, using worker_id to configure the dataset object to only read a specific fraction of a sharded dataset, or use seed to seed other libraries used in dataset code."
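Here is a minimal sketch of what that note describes, loosely following the sharded IterableDataset pattern from the PyTorch docs; the ShardedDataset class and the splitting math are my own illustration, not code from this thread. The worker_init_fn uses worker_id and get_worker_info() so that each worker only reads its own slice of the data.

import math
import torch
from torch.utils.data import IterableDataset, DataLoader, get_worker_info

class ShardedDataset(IterableDataset):
    # hypothetical dataset that yields integers in [start, end)
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        # yields whatever range this process was configured with
        return iter(range(self.start, self.end))

def worker_init_fn(worker_id):
    info = get_worker_info()
    ds = info.dataset  # each worker process holds its own copy of the dataset
    per_worker = math.ceil((ds.end - ds.start) / info.num_workers)
    ds.start = ds.start + worker_id * per_worker
    ds.end = min(ds.start + per_worker, ds.end)

loader = DataLoader(ShardedDataset(0, 20), batch_size=5, num_workers=2,
                    worker_init_fn=worker_init_fn)
for batch in loader:
    print(batch)  # worker 0 only yields 0-9, worker 1 only yields 10-19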
Yes, your description is correct, but each worker would collect 10 samples, as seen here:
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(20).view(-1, 1)

    def __getitem__(self, idx):
        # print which worker process is loading which sample index
        worker = torch.utils.data.get_worker_info()
        if worker:
            print("worker id {}, idx {}".format(worker.id, idx))
        x = self.data[idx]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
batch_size = 10
loader = DataLoader(dataset, batch_size, num_workers=2, prefetch_factor=2)

for data in loader:
    print(data.shape)
# worker id 0, idx 0
# worker id 0, idx 1
# worker id 0, idx 2
# worker id 0, idx 3
# worker id 0, idx 4
# worker id 0, idx 5
# worker id 0, idx 6
# worker id 0, idx 7
# worker id 0, idx 8
# worker id 0, idx 9
# worker id 1, idx 10
# worker id 1, idx 11
# worker id 1, idx 12
# worker id 1, idx 13
# worker id 1, idx 14
# worker id 1, idx 15
# worker id 1, idx 16
# worker id 1, idx 17
# worker id 1, idx 18
# worker id 1, idx 19
# torch.Size([10, 1])
# torch.Size([10, 1])
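To connect this back to the question: as far as I understand the current implementation, whole batches (not individual samples) are handed to the workers in round-robin order, and prefetch_factor only controls how many batches each worker keeps in flight. A small variation of the example above with a smaller batch size (my own sketch; the exact ordering of the prints depends on scheduling) should show this:

loader = DataLoader(dataset, batch_size=5, num_workers=2, prefetch_factor=2)
for data in loader:
    print(data.shape)
# With 4 batches and 2 workers the batches should be assigned round-robin:
# worker 0 loads idx 0-4 (batch 0) and idx 10-14 (batch 2),
# worker 1 loads idx 5-9 (batch 1) and idx 15-19 (batch 3),
# and each iteration prints torch.Size([5, 1]).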