I am trying to use information from outside functions to decide which data to return. Below is a simplified example that demonstrates the problem. With num_workers=0 I get the desired behavior (the final value after 4 epochs is 18). But when I increase num_workers, the output of each epoch is the same, and the global variable remains unchanged.
from torch.utils.data import Dataset, DataLoader

x = 6

def getx():
    global x
    x += 1
    print("x: ", x)
    return x

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        global x
        x = getx()
        return x

    def __len__(self):
        return 3

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=0,
    shuffle=False
)

for epoch in range(4):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, val: {}'.format(epoch, idx, data))
The final output with num_workers=0 is 18, as expected. But with num_workers>0, x remains unchanged (the final value is 6).
How can I get the same behavior with num_workers>0 as with num_workers=0, i.e. how can I ensure that the __getitem__ function of the dataset used by the DataLoader changes the value of the global variable x?
It would be great if someone can explain the problem and how to fix it.
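
For reference, here is a minimal sketch of what seems to be going on, using only the standard library rather than torch. With num_workers>0, the DataLoader runs __getitem__ in separate worker processes, and each process has its own copy of module-level globals, so an increment made in a worker never reaches the main process. A counter placed in shared memory (multiprocessing.Value) is visible to every process. The names bump_global, bump_shared, and run_demo are hypothetical helpers for illustration, and the "fork" start method assumes a Unix-like OS.

```python
import multiprocessing as mp

# Assumption: a Unix-like OS where the "fork" start method is available.
ctx = mp.get_context("fork")

x = 6  # plain module-level global, as in the question


def bump_global():
    """Increment the global; this only changes the child process's copy."""
    global x
    x += 1


def bump_shared(counter):
    """Increment a counter that lives in shared memory."""
    with counter.get_lock():  # guard the read-modify-write across processes
        counter.value += 1


def run_demo(n_calls=3):
    """Run each increment in child processes and report what the parent sees."""
    shared_x = ctx.Value("i", 6)  # shared C int, initial value 6
    for _ in range(n_calls):
        p = ctx.Process(target=bump_global)
        p.start()
        p.join()
        q = ctx.Process(target=bump_shared, args=(shared_x,))
        q.start()
        q.join()
    # The parent's global is untouched; the shared counter was updated.
    return x, shared_x.value
```

Under these assumptions, run_demo() returns (6, 9): the plain global stays at 6 in the parent, while the shared counter reaches 9. The same idea could presumably be applied to the example above by storing such a shared counter on the dataset and updating it inside __getitem__, though I have not verified that here, and with several workers the order of the updates would not be deterministic.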