The default value of the DataLoader's multiprocessing_context seems to be "spawn" inside a spawned process on Unix. I get OOM unless I set multiprocessing_context="fork" explicitly.
This doesn't match what the documentation says:
On Unix, fork() is the default multiprocessing start method. Using fork(), child workers typically can access the dataset and Python argument functions directly through the cloned address space.
I use torch.multiprocessing.spawn(main, ...) to run DDP training, and the dataloaders are initialized in main(). Memory usage goes over the limit when I don't set multiprocessing_context explicitly, and also when I set multiprocessing_context="spawn". Setting multiprocessing_context="fork", however, keeps memory usage much lower, around 25 GB out of 63 GB. The OOM happens at enumerate(dataloader).
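Here is a rough, self-contained sketch of the setup (not my actual script; the dataset, batch size, worker count, and rendezvous address are placeholders) showing where multiprocessing_context="fork" goes:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset


def main(rank, world_size):
    # DDP process group; "nccl" in my case, "gloo" would also work on CPU-only machines.
    dist.init_process_group("nccl", rank=rank, world_size=world_size,
                            init_method="tcp://127.0.0.1:23456")

    dataset = TensorDataset(torch.randn(10_000, 128))  # placeholder dataset

    # Without multiprocessing_context (or with "spawn") memory usage blows up
    # at enumerate(loader); with "fork" it stays around 25 GB.
    loader = DataLoader(dataset, batch_size=64, num_workers=4,
                        multiprocessing_context="fork")

    for epoch in range(2):
        for step, (batch,) in enumerate(loader):
            pass  # forward/backward/optimizer step elided

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    # Each spawned process builds its own dataloaders inside main().
    mp.spawn(main, args=(world_size,), nprocs=world_size)
```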
I tried running a simple script that mimics your described code, but I am not able to reproduce the issue - this might be due to having different environments, etc. Do you mind creating a GitHub issue in the PyTorch repo and filling out the detailed environment description there (PT version, CUDA, NCCL version, etc.)? That way, we can try to reproduce this issue in an environment similar to yours. Thanks!