I am trying to understand how workers are initialized, what data is copied/shared and in what order they are called.
I have the following snippet:
```python
import torch
from torch.utils.data import DataLoader, Dataset


def worker_init_fn(worker_id):
    info = torch.utils.data.get_worker_info()
    print(id(info.dataset), worker_id)


class DummyDataset(Dataset):
    def __init__(self):
        self.samples = list(range(8))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        print(f'Worker {torch.utils.data.get_worker_info().id} is used')
        return self.samples[idx]


ds = DummyDataset()
train_loader = DataLoader(
    ds, num_workers=4,
    worker_init_fn=worker_init_fn, batch_size=2,
)

for i, data in enumerate(train_loader):
    pass
```
which outputs:
```
140276885599920 0
140276885599920 1
Worker 0 is used
Worker 0 is used
140276885599920 2
Worker 2 is used
Worker 2 is used
Worker 1 is used
Worker 1 is used
140276885599920 3
Worker 3 is used
Worker 3 is used
```
- Workers do not seem to be initialized simultaneously. As can be seen, worker 0 is already being used before worker 2 is initialized. Is this correct? (The variant after this list adds timestamps to make the ordering easier to see.)
- Do they make a shallow or a deep copy of the dataset? From the output, it seems that each worker is passed the same dataset. However, I am not sure, since it could also be that each worker's dataset is garbage collected and the new one just happens to get the same `id` as the previous one. (The variant after this list also prints the process id for this reason.)
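To try to make both points easier to check, I came up with the variant below (the `touched_by` attribute, the pid/timestamp prints, and the final check are my own additions, so this may well not be the proper way to verify it). The idea is that `id()` values are only meaningful within a single process, so printing `os.getpid()` alongside them shows whether the matching ids actually come from the same process, and mutating the dataset inside `worker_init_fn` shows whether the main-process object is affected:

```python
import os
import time

import torch
from torch.utils.data import DataLoader, Dataset


def worker_init_fn(worker_id):
    info = torch.utils.data.get_worker_info()
    # id() is only meaningful within a single process, so print the pid too,
    # plus a timestamp to see when each worker is initialized.
    print(f'init worker={worker_id} pid={os.getpid()} '
          f'dataset_id={id(info.dataset)} t={time.monotonic():.3f}')
    # Mutate the dataset object this worker sees.
    info.dataset.touched_by.append(worker_id)


class DummyDataset(Dataset):
    def __init__(self):
        self.samples = list(range(8))
        self.touched_by = []  # appended to in worker_init_fn

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        info = torch.utils.data.get_worker_info()
        print(f'getitem worker={info.id} pid={os.getpid()} t={time.monotonic():.3f}')
        return self.samples[idx]


if __name__ == '__main__':
    ds = DummyDataset()
    train_loader = DataLoader(
        ds, num_workers=4,
        worker_init_fn=worker_init_fn, batch_size=2,
    )
    for i, data in enumerate(train_loader):
        pass
    # If the workers operated on the main-process object, this list would be
    # non-empty here; if they work on copies, it stays empty.
    print('main process sees:', ds.touched_by, 'pid:', os.getpid())
```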