Why does torch.nn.DataParallel cause unexpected CUDA OOM?


I’m using torch.nn.DataParallel for single-node data parallelism, and I’m wondering how the DataLoader batch size should be scaled.

I’m asking because I have code that runs fine with batch size 16 on one T4 GPU, but hits CUDA OOM with batch size 4 × 16 = 64 (and even with 48!) when using torch.nn.DataParallel over 4x T4 GPUs. Is torch.nn.DataParallel doing something unusual with memory, so that less than N times the memory of one GPU is available? Or is torch.nn.DataParallel already applying a scaling rule, so that the DataLoader batch is the per-GPU batch rather than the SGD-level batch? (I don’t think that’s the case, as the docs say it “splits the input across the specified devices by chunking in the batch dimension”.)
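For reference, nn.DataParallel treats the DataLoader batch as the global batch and splits it along dimension 0, so 64 samples over 4 GPUs gives each replica 16. The sketch below only reproduces that sizing arithmetic in plain Python, assuming the split follows `torch.chunk` semantics (chunks of ceil(batch / devices) samples, the last one possibly smaller); `chunk_sizes` is a hypothetical helper, not a PyTorch API.

```python
import math


def chunk_sizes(batch_size: int, num_devices: int) -> list[int]:
    """Per-device batch sizes, mimicking torch.chunk along dim 0 (hypothetical helper)."""
    if batch_size <= 0 or num_devices <= 0:
        raise ValueError("batch_size and num_devices must be positive")
    # torch.chunk uses ceil-sized chunks; the last chunk may be smaller, and
    # fewer than num_devices chunks come back when the batch is small.
    step = math.ceil(batch_size / num_devices)
    return [min(step, batch_size - start) for start in range(0, batch_size, step)]


# The DataLoader batch is the *global* batch; each replica gets one chunk.
print(chunk_sizes(64, 4))  # [16, 16, 16, 16]
print(chunk_sizes(48, 4))  # [12, 12, 12, 12]
print(chunk_sizes(10, 4))  # [3, 3, 3, 1]
```

Under this assumption both 64 and 48 give each T4 at most 16 samples, no more than the single-GPU run, so the per-replica batch alone should not explain the OOM.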

Note: I know PyTorch recommends DDP even for single-node data parallelism, but honestly I’m not smart enough to figure out all the torchrun/torch.distributed/launch.py tools, MPI, and local_rank things, and I couldn’t get DDP working after a week and 7 opened issues :slight_smile:

nn.DataParallel transfers the data to GPU0 first and scatters it to the other GPUs afterwards, so it easily causes OOM problems. If you can read Chinese, you can refer to pytorch ddp to understand how to use DDP.
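To make the imbalance concrete, here is a toy per-device memory model of one forward pass. It assumes the common pattern of calling `inputs.to("cuda:0")` before `model(inputs)` and the default `output_device` (the first GPU). The function names and per-sample sizes are hypothetical, and parameters, gradients and optimizer state (which also end up on GPU 0) are left out.

```python
import math


def split_batch(batch_size: int, num_devices: int) -> list[int]:
    # Chunk sizes along dim 0, assuming torch.chunk-style splitting.
    if batch_size <= 0 or num_devices <= 0:
        raise ValueError("batch_size and num_devices must be positive")
    step = math.ceil(batch_size / num_devices)
    return [min(step, batch_size - start) for start in range(0, batch_size, step)]


def dataparallel_memory_estimate(batch_size: int, num_devices: int,
                                 input_size: float, activation_size: float,
                                 output_size: float,
                                 input_on_gpu0: bool = True) -> list[float]:
    """Rough per-device memory for one nn.DataParallel forward pass.

    Hypothetical model: sizes are per sample, in any consistent unit (e.g. MiB).
    Parameters, gradients and optimizer state are ignored, although they
    also accumulate on GPU 0.
    """
    chunks = split_batch(batch_size, num_devices)
    usage = [0.0] * num_devices
    for dev, n in enumerate(chunks):
        # each replica holds its input chunk plus the activations it produces
        usage[dev] += n * (input_size + activation_size)
    if input_on_gpu0:
        # inputs.to("cuda:0") keeps the full batch on GPU 0; the chunks sent
        # to the other GPUs are additional copies of that data
        usage[0] += (batch_size - chunks[0]) * input_size
    # outputs from every replica are gathered onto GPU 0 (default output_device)
    usage[0] += batch_size * output_size
    return usage


# Hypothetical per-sample sizes: 1 unit of input, 20 of activations, 1 of output.
print(dataparallel_memory_estimate(64, 4, 1.0, 20.0, 1.0))
# [448.0, 336.0, 336.0, 336.0]
```

Even in this toy model GPU 0 needs a third more memory than the other three (448 vs 336 units), so it is the first to run out while the others still have headroom.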


OK, so if I understand correctly, with nn.DataParallel all the data for the whole cluster must be able to fit on one GPU? That doesn’t make sense, right? If all the data fit on one GPU, people wouldn’t use data parallelism in the first place :upside_down_face: :face_with_raised_eyebrow:

The data is still split along its batch dimension, so the forward activations, which are usually much larger than the input data, are distributed across all devices.
However, @techkang is right: the scatter/gather ops from the default device create a memory imbalance, which is why we recommend using DDP (besides DDP also being faster).
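A quick back-of-the-envelope check of why activations dominate. The shapes are hypothetical (3x224x224 float32 images and a single 3x3 convolution with 64 output channels, stride 1, padding 1), chosen only to illustrate the ratio for one replica's 16-sample chunk.

```python
def tensor_bytes(shape: tuple[int, ...], bytes_per_element: int = 4) -> int:
    """Memory needed by a dense tensor; 4 bytes per element for float32."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_element


per_replica_batch = 16  # 64 samples split over 4 GPUs
# Hypothetical shapes: 3x224x224 images, and one 3x3 conv (stride 1, padding 1)
# with 64 output channels, which keeps the 224x224 spatial size.
input_bytes = tensor_bytes((per_replica_batch, 3, 224, 224))
conv_out_bytes = tensor_bytes((per_replica_batch, 64, 224, 224))

print(f"input chunk:     {input_bytes / 2**20:.1f} MiB")      # 9.2 MiB
print(f"one conv output: {conv_out_bytes / 2**20:.1f} MiB")   # 196.0 MiB
print(f"ratio:           {conv_out_bytes / input_bytes:.1f}x")  # 21.3x
```

A real network keeps many such intermediate tensors alive for the backward pass, each on the GPU that produced it, which is why splitting the batch still spreads most of the memory across devices.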