DataLoader shows strange behavior for num_workers > 0

I am trying to use information from outside functions to decide which data to return. Below is a simplified example that demonstrates the problem. With num_workers=0 I get the desired behavior (the value after the 4 epochs is 18). But when I increase num_workers, every epoch produces the same output, and the global variable x remains unchanged.

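```python
from torch.utils.data import Dataset, DataLoader

x = 6

def getx():
    # Increment the module-level global x and return the new value.
    global x
    x += 1
    print("x: ", x)
    return x

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        # Called once per sample; updates the global x through getx().
        global x
        x = getx()
        return x

    def __len__(self):
        return 3

# The main guard is required when workers are started with "spawn" (e.g. Windows, macOS).
if __name__ == "__main__":
    dataset = MyDataset()
    loader = DataLoader(
        dataset,
        num_workers=0,
        shuffle=False
    )

    for epoch in range(4):
        for idx, data in enumerate(loader):
            print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))

    print("final x:", x)
```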

With num_workers=0 the final value of x is 18, as expected. With num_workers>0, however, x remains unchanged in the main process (its final value is 6).

How can I get the same behavior as with num_workers=0 when using num_workers>0, i.e. how can I make sure that the dataset's __getitem__ actually changes the value of the global variable x?

It would be great if someone could explain the problem and how to fix it.


Good question!

From the docs:

In this mode, each time an iterator of a DataLoader is created (e.g., when you call enumerate(dataloader)), num_workers worker processes are created. At this point, the dataset, collate_fn, and worker_init_fn are passed to each worker, where they are used to initialize, and fetch data. This means that dataset access together with its internal IO, transforms (including collate_fn) runs independently in the worker process.
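The same effect can be reproduced without PyTorch. The sketch below uses a plain multiprocessing.Process as a stand-in for a DataLoader worker (an analogy, not DataLoader's actual implementation; fetch and run_in_child are hypothetical names). The child increments its own copy of the global x, so every fresh process starts again from 6 and the parent's x never changes.

```python
import multiprocessing as mp

x = 6  # module-level global, as in the question


def getx():
    global x
    x += 1
    return x


def fetch(n_items, out_queue):
    # Runs in a child process, which has its own copy of the global x.
    out_queue.put([getx() for _ in range(n_items)])


def run_in_child(n_items=3):
    """Start a fresh process (like one epoch of a worker) and collect its values."""
    q = mp.Queue()
    p = mp.Process(target=fetch, args=(n_items, q))
    p.start()
    values = q.get()  # read before join() so the child can flush the queue
    p.join()
    return values


if __name__ == "__main__":
    for epoch in range(2):
        print(epoch, run_in_child())  # the same values every time
    print("x in parent:", x)          # still 6
```

This works the same way under the fork start method (the child copies x=6) and under spawn (the child re-imports the module, so x is 6 again).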

This implies:

  • with num_workers=0, data is fetched by the main process in every epoch, so the global x of the main process is updated.

  • with num_workers=1, the main process spawns a new worker each time enumerate(loader) is called, i.e. at every epoch (with the default persistent_workers=False). The new worker starts from its own copy of x, which is still 6, so you should expect 7, 8, 9 at every epoch.

  • with num_workers=2, the main process spawns 2 new workers each time enumerate(loader) is called. Since len(dataset)=3, worker1 gets 2 indices and worker2 gets 1 index, so you should expect 7, 8 from worker1 and 7 from worker2.

  • If you change len(dataset) to 4 (values listed per worker):

    • with num_workers=2, you should expect 7, 8 from worker1 and 7, 8 from worker2
    • with num_workers=3, you should expect 7, 8 from worker1, 7 from worker2 and 7 from worker3
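One way to check these predictions is to tag every sample with the id of the worker that produced it, using torch.utils.data.get_worker_info() (ids start at 0, so worker1/worker2 above correspond to ids 0/1). TaggedDataset and one_epoch are hypothetical names for this sketch, which assumes the default in-order loading where indices are handed to the workers in turn; that dispatch order is an implementation detail rather than a documented guarantee.

```python
from torch.utils.data import Dataset, DataLoader, get_worker_info

x = 6


def getx():
    global x
    x += 1
    return x


class TaggedDataset(Dataset):
    """Hypothetical variant of MyDataset that also reports which worker served each index."""

    def __init__(self, length):
        self.length = length

    def __getitem__(self, index):
        info = get_worker_info()  # None when called in the main process
        worker_id = -1 if info is None else info.id
        return index, worker_id, getx()

    def __len__(self):
        return self.length


def one_epoch(length, num_workers):
    """Return (index, worker_id, value) for every sample of one epoch."""
    loader = DataLoader(TaggedDataset(length), num_workers=num_workers, shuffle=False)
    # With the default batch_size=1, each field is collated into a 1-element tensor.
    return [tuple(int(field) for field in batch) for batch in loader]


if __name__ == "__main__":
    for num_workers in (1, 2, 3):
        print(num_workers, one_epoch(4, num_workers))
```

Note that the loop receives the samples in index order, so with len(dataset)=4 and num_workers=2 the values arrive as 7, 7, 8, 8 even though each worker produced 7, 8.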

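This is not a GIL limitation: each worker is a separate process with its own copy of the dataset and of module-level globals such as x, so changes made inside a worker never reach the main process (and are discarded when the workers shut down at the end of the epoch). To share state between processes, you can look into multiprocessing communication constructs.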
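One such construct is multiprocessing.Value, a C value stored in shared memory. The standard-library sketch below (fetch_epoch and run_epochs are hypothetical names, and a plain Process again stands in for a DataLoader worker) starts a fresh process per "epoch", as a DataLoader does by default, yet the counter keeps growing because every process updates the same shared value.

```python
import multiprocessing as mp


def fetch_epoch(counter, n_items, out_queue):
    """Stand-in for one epoch of a worker: update the shared counter n_items times."""
    values = []
    for _ in range(n_items):
        with counter.get_lock():  # "+= 1" on a shared Value is not atomic by itself
            counter.value += 1
            values.append(counter.value)
    out_queue.put(values)


def run_epochs(epochs=4, n_items=3, start=6):
    counter = mp.Value("i", start)  # C int living in shared memory
    history = []
    for _ in range(epochs):
        # A fresh process per epoch, like a DataLoader with persistent_workers=False.
        q = mp.Queue()
        p = mp.Process(target=fetch_epoch, args=(counter, n_items, q))
        p.start()
        history.append(q.get())
        p.join()
    return counter.value, history


if __name__ == "__main__":
    final, history = run_epochs()
    print(history)
    print("final value seen by the parent:", final)
```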

For this problem, shared-memory tensors (see Tensor.share_memory_()) can be used.
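A minimal sketch of the shared-tensor approach follows. SharedCounterDataset and run are hypothetical names, and this is not necessarily the solution from the linked answer. The counter lives in a tensor moved to shared memory with share_memory_(), and a multiprocessing.Lock serializes the read-modify-write because several workers may call __getitem__ at the same time. It assumes the lock and the DataLoader workers use the same (default) start method. Since all workers update the same storage, the main process should see 18 after 4 epochs even with num_workers=2, although the order of values within an epoch can vary between runs.

```python
import multiprocessing

import torch
from torch.utils.data import Dataset, DataLoader


class SharedCounterDataset(Dataset):
    """Hypothetical rewrite of MyDataset whose counter is shared by all workers."""

    def __init__(self, start=6, length=3):
        # share_memory_() moves the storage into shared memory, so worker
        # processes update this tensor instead of private copies of it.
        self.x = torch.tensor([start], dtype=torch.int64).share_memory_()
        # Several workers may run __getitem__ concurrently; the lock makes the
        # read-modify-write below atomic across processes.
        self.lock = multiprocessing.Lock()
        self.length = length

    def __getitem__(self, index):
        with self.lock:
            self.x += 1  # in-place update of the shared storage
            return int(self.x.item())

    def __len__(self):
        return self.length


def run(num_workers=2, epochs=4):
    dataset = SharedCounterDataset()
    loader = DataLoader(dataset, num_workers=num_workers, shuffle=False)
    for epoch in range(epochs):
        for idx, data in enumerate(loader):
            print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
    return int(dataset.x.item())  # value as seen by the main process


if __name__ == "__main__":
    print("final x:", run())
```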

I asked the same question on Stack Overflow, and the solution given there works:
Stack Overflow answer

Hope this helps others facing a similar problem.
