Balance memory for DataParallel

nn.DataParallel(model) efficiently parallelizes batches for me. However, when looking at memory usage, I see that device 0 is almost full while the other devices have memory to spare. Is there a way to balance the memory load (e.g., by splitting batches unevenly across devices)?
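One known contributor to the imbalance is that DataParallel gathers every replica's outputs onto device 0 and computes the loss there. A rough workaround is to compute the loss inside forward() so only small per-replica loss tensors are gathered back. A minimal sketch (the ModelWithLoss wrapper and the toy linear model below are illustrative placeholders, not anything from the original setup):

```python
import torch
import torch.nn as nn

# Illustrative wrapper (not a built-in): computing the loss inside forward()
# means each replica returns only a small loss tensor, so far less output
# memory is gathered back onto device 0.
class ModelWithLoss(nn.Module):
    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, inputs, targets):
        return self.criterion(self.model(inputs), targets)

model = nn.Linear(512, 1000)                        # toy stand-in for the real model
wrapped = nn.DataParallel(ModelWithLoss(model, nn.CrossEntropyLoss())).cuda()

inputs = torch.randn(256, 512).cuda()
targets = torch.randint(0, 1000, (256,)).cuda()
loss = wrapped(inputs, targets).mean()              # average the per-GPU losses
loss.backward()
```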

Thanks for posting the question @dyukha. Yes, DDP supports uneven inputs starting from PyTorch 1.8.1; you can take a look at the details in the docs: DistributedDataParallel — PyTorch 1.9.0 documentation
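For reference, the uneven-inputs support works through the join() context manager on the DDP-wrapped model. A fragment only, assuming ddp_model (a DistributedDataParallel instance), dataloader, optimizer, and loss_fn are already set up on each rank:

```python
# Assumes ddp_model, dataloader, optimizer, and loss_fn already exist on each
# rank, and that ranks may see a different number of batches.
with ddp_model.join():
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(ddp_model(inputs), targets)
        loss.backward()   # join() shadows the collectives for ranks that finish early
        optimizer.step()
```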

@dyukha please also use DDP instead of DataParallel. DDP is the better choice even on a single machine, and we are trying to deprecate DataParallel in the long term as well; see DataParallel — PyTorch 1.9.0 documentation
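For anyone landing here, a minimal single-node, one-process-per-GPU DDP sketch (the toy linear model, random data, and port choice are placeholders, not part of the original question):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def run(rank, world_size):
    # One process per GPU; NCCL backend for GPU collectives.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = nn.Linear(10, 10).to(rank)          # toy stand-in for the real model
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    for _ in range(5):
        inputs = torch.randn(32, 10, device=rank)    # each rank uses its own data
        targets = torch.randn(32, 10, device=rank)
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(ddp_model(inputs), targets)
        loss.backward()                              # gradients all-reduced across ranks
        optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
```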

Thanks for the reply! I have to use DataParallel because of issues with DDP: Distributed Data Parallel example - "process 0 terminated with exit code 1" - #3 by dyukha