I’m trying to load a large .npy file for training and I’m using IterableDataset. However, memory usage increases when I set num_workers > 0, because the dataset gets replicated in each worker process.
Is there a way to read the dataset once and have each worker process only its own portion of it?
https://pytorch.org/docs/stable/data.html?highlight=concat#torch.utils.data.IterableDataset
This doc should answer your question: define a worker_init_fn
that configures each dataset copy differently, or use torch.utils.data.get_worker_info() inside __iter__ so each worker yields only its own shard.
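A minimal sketch of that approach (the class name and file path are placeholders, not from the original thread): each worker memory-maps the .npy file with `np.load(..., mmap_mode="r")` so the array stays on disk instead of being copied into every process, and `get_worker_info()` is used in `__iter__` to give each worker a disjoint slice of the indices.

```python
import math

import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedNpyDataset(IterableDataset):
    """Hypothetical example: memory-map a .npy file and shard it per worker."""

    def __init__(self, path):
        # Store only the path; open the file lazily inside each worker
        # so nothing large is pickled when workers are spawned.
        self.path = path

    def __iter__(self):
        # mmap_mode="r" keeps the data on disk; workers share OS page cache
        # pages instead of each holding a full in-memory copy.
        data = np.load(self.path, mmap_mode="r")
        info = get_worker_info()
        if info is None:
            # num_workers == 0: single process iterates everything.
            start, end = 0, len(data)
        else:
            # Split the index range into contiguous, non-overlapping shards.
            per_worker = math.ceil(len(data) / info.num_workers)
            start = info.id * per_worker
            end = min(start + per_worker, len(data))
        for i in range(start, end):
            # Copy the row out of the memmap before handing it to torch.
            yield torch.from_numpy(np.array(data[i]))
```

Usage would look like `DataLoader(ShardedNpyDataset("data.npy"), batch_size=None, num_workers=4)`; with this sharding, each sample is produced by exactly one worker rather than num_workers times.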