Is DistributedDataParallel taking up more CPU memory?

Hi, I have a time series data shaped as follows (n_samples, n_frequency, n_channels, n_timesteps) = (21600, 9, 62, 1000).

When I load the data using a customised Dataset on a single machine, it consumes about 40 GB of CPU memory.
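For reference, the Dataset looks roughly like this (a simplified sketch; the class name and file paths are illustrative, but the point is that the full array is materialised in __init__):

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class TimeSeriesDataset(Dataset):
    """Holds the full (n_samples, 9, 62, 1000) array in CPU RAM."""

    def __init__(self, data_path, label_path):
        # Everything is loaded up front, so each process that constructs
        # this Dataset pays the full memory cost.
        self.data = torch.from_numpy(np.load(data_path))     # (21600, 9, 62, 1000)
        self.labels = torch.from_numpy(np.load(label_path))  # (21600,)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
```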

When I use torchrun and DDP, the CPU memory usage is multiplied by the number of GPUs I have.

With torchrun --standalone --nnodes=1 --nproc_per_node=n_gpu, it seems that torchrun launches the script n_gpu times simultaneously, and therefore the data is loaded n_gpu times.

Is there any way to resolve this issue?

Take a look at DistributedSampler; it should help restrict each process's data loading to a subset of the data: torch.utils.data — PyTorch 2.0 documentation
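A minimal sketch of how it fits together, assuming torchrun has spawned the processes and the dataset/hyperparameter names are placeholders for your own:

```python
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dist.init_process_group(backend="nccl")  # torchrun sets RANK / WORLD_SIZE / LOCAL_RANK

dataset = TimeSeriesDataset("data.npy", "labels.npy")  # placeholder Dataset

# Each rank only iterates over roughly len(dataset) / world_size indices.
sampler = DistributedSampler(dataset, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                    num_workers=4, pin_memory=True)

for epoch in range(10):
    sampler.set_epoch(epoch)  # reshuffle the per-rank partition each epoch
    for x, y in loader:
        ...  # forward/backward with the DDP-wrapped model
```

One caveat: the sampler only restricts which indices each rank reads. If the Dataset still materialises the whole array in __init__, every process keeps a full copy; the memory saving comes from loading samples lazily in __getitem__, e.g. from memory-mapped files (np.load(..., mmap_mode="r")) or per-sample files.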