How do DataLoader workers work under the hood?

I am trying to understand how workers are initialized, what data is copied/shared and in what order they are called.

I have the following snippet:

```python
import torch
from torch.utils.data import DataLoader, Dataset


def worker_init_fn(worker_id):
    info = torch.utils.data.get_worker_info()
    print(id(info.dataset), worker_id)


class DummyDataset(Dataset):
    def __init__(self):
        self.samples = list(range(8))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        print(f'Worker {torch.utils.data.get_worker_info().id} is used')
        return self.samples[idx]


if __name__ == '__main__':  # needed on platforms that spawn workers (Windows, macOS)
    ds = DummyDataset()
    train_loader = DataLoader(
        ds, num_workers=4,
        worker_init_fn=worker_init_fn, batch_size=2,
    )

    for i, data in enumerate(train_loader):
        pass
```

which outputs:

```
140276885599920 0
140276885599920 1
Worker 0 is used
Worker 0 is used
140276885599920 2
Worker 2 is used
Worker 2 is used
Worker 1 is used
Worker 1 is used
140276885599920 3
Worker 3 is used
Worker 3 is used
```

  • Workers are not initialized simultaneously. As can be seen, worker 0 is already used before worker 2 is initialized. Is this correct?
  • Do they make a shallow or a deep copy of the dataset? From the output, it seems that each worker is passed the same dataset. However, I am not sure, since it could be that each previous dataset is garbage collected and the new one just happens to get the same id.
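A rough sketch of how batches end up on workers, assuming (as PyTorch's default multi-process iterator appears to do) that the main process hands out batch indices to the workers in round-robin order. The helper assign_batches is hypothetical, not a PyTorch API. With 8 samples, batch_size=2 and num_workers=4, each worker receives exactly one batch of two indices, which matches every "Worker k is used" line appearing exactly twice. The relative order of the prints is a separate matter: the workers are started one after another, but they then run as separate processes concurrently, so the interleaving of their output is not deterministic.

```python
from itertools import cycle


def assign_batches(num_samples, batch_size, num_workers):
    """Hypothetical helper: pair each batch of indices with a worker id, round-robin."""
    indices = list(range(num_samples))
    # Split indices into consecutive batches, like a sequential sampler with batch_size.
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    workers = cycle(range(num_workers))  # 0, 1, 2, 3, 0, 1, ...
    return [(next(workers), batch) for batch in batches]


for worker_id, batch in assign_batches(num_samples=8, batch_size=2, num_workers=4):
    print(f"worker {worker_id} -> indices {batch}")
```

With fewer workers than batches, the cycle wraps around, so with num_workers=2 worker 0 would also get the third batch.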

If you initialize self.samples as a floating-point tensor and manipulate it in place via:

```python
def __getitem__(self, idx):
    print(f'Worker {torch.utils.data.get_worker_info().id} is used')
    self.samples[:] += 0.1
    return self.samples[idx]
```

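you'll see that each worker's copy of self.samples is bumped only twice (once per __getitem__ call, since each worker loads a single batch of two samples) instead of accumulating all eight updates, which indicates that each worker uses its own copy of the dataset. The original ds.samples tensor also keeps its originally assigned values after the DataLoader loop.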
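A torch-free sketch of the same experiment, assuming Unix-style fork() for creating workers (os.fork is not available on Windows). The names worker_body and run_forked_workers are hypothetical, and the children here run one after another rather than concurrently, which keeps the result deterministic but differs from a real DataLoader. Each child gets one batch of two indices, bumps its inherited copy of samples in place once per item, and sends back the values it returned.

```python
import os
import pickle

# Stand-in for DummyDataset.samples, stored as floats so "+= 0.1" is valid.
samples = [float(i) for i in range(8)]


def worker_body(batch):
    """Runs in a forked child: mimic __getitem__ bumping the whole copy in place."""
    returned = []
    for idx in batch:
        for i in range(len(samples)):
            samples[i] += 0.1  # modifies this process's copy only
        returned.append(samples[idx])
    return returned


def run_forked_workers(num_workers=4, batch_size=2):
    """Fork one child per batch (hypothetical helper) and collect what each returned."""
    results = {}
    for worker_id in range(num_workers):
        batch = list(range(worker_id * batch_size, (worker_id + 1) * batch_size))
        read_fd, write_fd = os.pipe()
        pid = os.fork()
        if pid == 0:  # child process
            try:
                os.close(read_fd)
                with os.fdopen(write_fd, "wb") as pipe:
                    pickle.dump(worker_body(batch), pipe)
            finally:
                os._exit(0)  # never fall back into the parent's code path
        os.close(write_fd)  # parent process
        with os.fdopen(read_fd, "rb") as pipe:
            results[worker_id] = pickle.load(pipe)
        os.waitpid(pid, 0)
    return results


results = run_forked_workers()
print(results)  # each child bumped its own copy only twice
print(samples)  # the parent's list still holds 0.0 ... 7.0
```

Every returned value is at most 0.2 above its original, and the parent's samples list is untouched, mirroring the behavior described above.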


Thanks for the answer @ptrblck! I was confused because I didn't understand how multiprocessing works under the hood. At least on Unix, where the workers are typically created with fork(), memory is shared copy-on-write, meaning that changes a worker makes to its dataset can't affect the other workers or the original dataset in the main process. I have also written an answer on Stack Overflow to better express these ideas. It would be very helpful if you could comment on the SO question, because I think the answers there are wrong and could lead to misunderstandings for future readers and users of PyTorch.
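The identical id() values in the original output are consistent with this. A minimal sketch, assuming a Unix system where os.fork is available: a forked child inherits a copy of the parent's address space, so an inherited object keeps the same address and therefore (in CPython, where id() is the object's address) the same id(), even though the child's copy is private. The Dataset class below is a hypothetical stand-in for DummyDataset. When workers are started with spawn instead (the default on Windows and macOS), the dataset is pickled into each worker, so the ids would generally differ, although each worker still ends up with its own copy.

```python
import os


class Dataset:
    """Hypothetical stand-in for DummyDataset."""

    def __init__(self):
        self.samples = list(range(8))


ds = Dataset()
parent_id = id(ds)

read_fd, write_fd = os.pipe()
pid = os.fork()
if pid == 0:  # child: mutate the inherited copy and report id(ds)
    try:
        os.close(read_fd)
        ds.samples.append(99)  # writing copies the touched pages for the child only
        os.write(write_fd, f"{id(ds)} {len(ds.samples)}".encode())
    finally:
        os._exit(0)

os.close(write_fd)  # parent: read the child's report
child_id, child_len = map(int, os.read(read_fd, 1024).decode().split())
os.close(read_fd)
os.waitpid(pid, 0)

print(child_id == parent_id)  # True: same virtual address in both processes
print(child_len, len(ds.samples))  # 9 8: the child's change did not reach the parent
```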