Hi,
I have an issue when using DataLoader, and I am not sure if this is expected behavior.
I use torch==1.11.0
I am trying to update the labels of a dataset at each epoch.
I tried:
dataset.labels = new_labels
and:
dataloader.dataset.labels = new_labels
But it did not work. I found that the issue comes from the fact that I use persistent_workers=True. If I set persistent_workers=False it works, but training becomes too slow.
Here is a quick way to reproduce the issue:
from torch.utils.data import Dataset, DataLoader


class Foo(Dataset):
    def __init__(self, n=15):
        self.data = [1] * n

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def bar(persistent_workers=False):
    dts = Foo()
    loader = DataLoader(dts, batch_size=5, num_workers=5, persistent_workers=persistent_workers)

    # First epoch: the dataset contains ones.
    for batch in loader:
        print(batch)

    # Rebind the dataset's data in the main process.
    dts.data = [0] * len(dts.data)
    print("Changing data")

    # Second epoch: I expect zeros here.
    for batch in loader:
        print(batch)


bar(False)
print()
print("Now it is bugging")
print()
bar(True)
The two calls produce different outputs, with the first call showing the expected behaviour:
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
Changing data
tensor([0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])
Now it is bugging
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
Changing data
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
Is there a way to overwrite the dataset while still using persistent_workers=True?
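
For what it is worth, the only workaround I have found so far is to keep the data in a tensor placed in shared memory before the workers are started, and to update it in place instead of rebinding the attribute, so the persistent workers read from the same underlying storage. I am not sure this is the intended approach, and the SharedFoo class below is only a sketch of the idea:

import torch
from torch.utils.data import Dataset, DataLoader


class SharedFoo(Dataset):
    def __init__(self, n=15):
        # Move the storage to shared memory before the workers exist, so the
        # dataset copies held by the persistent workers all point at it.
        self.data = torch.ones(n, dtype=torch.long).share_memory_()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


dts = SharedFoo()
loader = DataLoader(dts, batch_size=5, num_workers=5, persistent_workers=True)

for batch in loader:
    print(batch)

# Update in place: rebinding dts.data to a new object would not be visible to
# the workers, but an in-place write goes to the shared storage.
dts.data[:] = 0
print("Changing data")

for batch in loader:
    print(batch)

I have not tested this extensively, and it forces me to keep the labels in tensors, so I would still be interested in a cleaner way.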
Thank you for any help,
Elias