DataLoader(..., num_workers>0, ...) does not update Dataset

Hey,

Please have a look at the following example:

import numpy as np
import torch

from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Toy dataset that records the maximum of every sample it serves.

    ``max_id`` starts as all ones; entry ``idx`` is overwritten with the
    maximum of sample ``idx`` each time that sample is requested.
    """

    def __init__(self):
        # Ten random 2x2 integer samples with values drawn from [0, 100).
        self.data = np.random.randint(0, 100, size=(10, 2, 2))
        self.max_id = np.ones(self.data.shape[0], dtype=np.int16)

    def __len__(self):
        """Return the number of samples."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx`` and store its maximum in ``max_id[idx]``."""
        sample = self.data[idx, :, :]
        # Side effect: this is the bookkeeping the example is about.
        self.max_id[idx] = sample.max()
        return sample


def main():
    """Print ``dataset.max_id`` before and after fetching one batch.

    One batch is fetched with ``num_workers=1`` and then one with
    ``num_workers=0``; the parent's ``max_id`` is printed after each.
    """
    # Seed both generators so repeated runs behave the same.
    np.random.seed(42)
    torch.manual_seed(42)
    dataset = MyDataset()

    # First pass: samples are fetched by a single worker.
    data_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
    batch_iter = iter(data_loader)
    print(f'Initial State: {dataset.max_id}')
    first_batch = next(batch_iter)
    print(f'After num_worker=1: {dataset.max_id}')

    # Second pass: same dataset object, but loaded with no workers.
    data_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    batch_iter = iter(data_loader)
    first_batch = next(batch_iter)
    print(f'After num_worker=0: {dataset.max_id}')


# Run the demonstration only when this file is executed as a script.
if __name__ == '__main__':
    main()

In the last case, where DataLoader(..., num_workers=0, ...), I observe the intended behaviour: dataset.max_id is updated.

How can I achieve this with e.g. DataLoader(..., num_workers=1, ...)?

I am using this in the context of large image files, where I want to read information from these images only upon calling them for training/validation. Thank you for any hint/advice!

Cheers

Is the question trivial? Is the explanation insufficient?

I would very much appreciate any hint as to what property of the DataLoader causes this behaviour and how to best circumvent it.

@dsethz When you specify num_workers > 0, multiple child processes are spawned to perform the actual data loading. As a result, the dataset object would be updated in the child processes and not the parent process where you are printing max_id.

One way to get the max_ids from the child processes would be to put them in a multiprocessing queue in the child processes and read from the same queue in the parent process.

Hey @pritamdamania87,

thank you for your feedback. I am not experienced with multiprocessing, but if I understand you correctly, I need a custom data loader in which I adapt _MultiProcessingDataLoaderIter ?

Cheers

I don’t think you need a custom dataloader and using a multiprocessing queue in the dataset should suffice:

# Create the queue at module level; the dataset puts its max_id updates on it
# and main() reads them back.
import multiprocessing as mp
q = mp.Queue()


# Pass the queue to the dataset
class MyDataset(Dataset):
    """Toy dataset that reports its ``max_id`` bookkeeping through a queue.

    Args:
        q: Queue-like object with a ``put`` method. A snapshot of ``max_id``
            is put on it once at construction and again after every
            ``__getitem__`` call.
    """

    def __init__(self, q):
        self.data = np.random.randint(0, 100, (10, 2, 2))
        self.max_id = np.ones(len(self.data), dtype=np.int16)
        self.q = q
        # Publish the initial state. A copy is enqueued so that later in-place
        # updates cannot alter a snapshot that is still waiting in the queue.
        self.q.put(self.max_id.copy())

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` and publish the updated ``max_id``."""
        self.max_id[idx] = np.max(self.data[idx, :, :])
        # Use the queue handed to this instance, not a module-level global:
        # the global may not exist, or may be a different queue, wherever
        # this method runs.
        self.q.put(self.max_id.copy())
        return self.data[idx, :, :]

# Then in the main process

def main():
    """Fetch one batch through a worker and print the ``max_id`` it reported.

    Raises:
        queue.Empty: If an expected snapshot does not arrive on ``q`` within
            the timeout.
    """
    np.random.seed(42)
    torch.manual_seed(42)
    dataset = MyDataset(q)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
    loader_iter = iter(loader)

    print(f'Initial State: {dataset.max_id}')

    batch = next(loader_iter)

    # The queue is first-in, first-out and MyDataset.__init__ already put the
    # initial state on it, so a single get() would only return that stale
    # snapshot. Read past it.
    max_id = q.get(timeout=10)
    # __getitem__ puts one snapshot per sample; keep the one taken after the
    # last sample of this batch. The timeout makes a missing report fail
    # loudly instead of blocking forever.
    for _ in range(len(batch)):
        max_id = q.get(timeout=10)
    print(f'After num_worker=1: {max_id}')
1 Like