Dataset overwrite is not possible when `persistent_workers=True`

Hi,

I have an issue when using DataLoader, and I am not sure if this is expected behavior.

I use torch==1.11.0

I am trying to update the labels of a dataset at each epoch.
I tried:

dataset.labels = new_labels

and:

dataloader.dataset.labels = new_labels

But neither worked. I found that the issue comes from the fact that I use persistent_workers=True. If I set persistent_workers=False this works, but the training becomes too slow.

Here is a quick way to reproduce the issue:

from torch.utils.data import Dataset, DataLoader


class Foo(Dataset):

    def __init__(self, n=15):
        self.data = [1] * n

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def bar(persistent_workers=False):
    dts = Foo()
    loader = DataLoader(dts, batch_size=5, num_workers=5, persistent_workers=persistent_workers)

    for batch in loader:
        print(batch)

    dts.data = [0] * len(dts.data)

    print("Changing data")
    for batch in loader:
        print(batch)


bar(False)
print()
print("Now it is bugging")
print()
bar(True)

So both calls produce different output, with the first call having the expected behaviour:

tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
Changing data
tensor([0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0])

Now it is bugging

tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
Changing data
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1])

Would there be a way to still overwrite the dataset while using persistent_workers=True?

Thank you for any help,
Elias

No, I don’t think that’s easily possible, as each worker uses its own copy of the dataset, so manipulations in the main process would not be reflected. You would either need to recreate the DataLoader (after manipulating the Dataset) or use your first approach and avoid persistent workers.
There might be a hacky way to use the Dataset.__getitem__ method with the worker information, but I don’t know if this would work (and it doesn’t sound too clean to be honest).

Hi, thank you for your answer!