How to break a single large input across different GPUs when using DDP?


I am using a text dataset that has some really long documents, so fitting even one such input on a single GPU runs into an OOM error.

Any suggestions on how to break the input across multiple GPUs while still doing the backward pass efficiently?


I suppose the model and the input should be on the same device. If you have multiple GPUs, one thing you could do is use model parallelism in PyTorch: distribute your model's layers across the multiple GPUs and keep only a few, or even just one, layer on the GPU that holds your input.
This might not be exactly what you wanted, but it may help. The model parallel tutorial is here -
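As a rough illustration of the idea above, here is a minimal model-parallel sketch (the `TwoStageModel` class, layer sizes, and device choices are made up for the example; it falls back to CPU when two GPUs are not available):

```python
import torch
import torch.nn as nn

# Place each stage on its own device; fall back to CPU without 2 GPUs.
dev0 = torch.device("cuda:0" if torch.cuda.device_count() >= 2 else "cpu")
dev1 = torch.device("cuda:1" if torch.cuda.device_count() >= 2 else "cpu")

class TwoStageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Each stage lives entirely on one device.
        self.stage1 = nn.Linear(768, 768).to(dev0)
        self.stage2 = nn.Linear(768, 2).to(dev1)

    def forward(self, x):
        x = torch.relu(self.stage1(x.to(dev0)))
        # Move the activations between devices between stages.
        return self.stage2(x.to(dev1))

model = TwoStageModel()
out = model(torch.randn(4, 768))
print(out.shape)  # torch.Size([4, 2])
```

Only the activations cross device boundaries in `forward`, and autograd handles the backward pass across devices automatically.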

Thank you for your reply. I tried that. I am using multiple BERT models, so breaking them down across several GPUs leads to more issues.

Try using data parallelism. It’s easy to use if you have access to multiple GPUs. Make sure that each per-GPU batch is still small enough to fit on a single GPU.
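A minimal sketch of that suggestion using `nn.DataParallel`, which replicates the model on each visible GPU and splits the batch along dimension 0 (the model and sizes here are placeholders; on a single-GPU or CPU machine it simply runs the plain model):

```python
import torch
import torch.nn as nn

model = nn.Linear(768, 2)
if torch.cuda.device_count() > 1:
    # Replicate the model across GPUs; each replica gets a batch slice.
    model = nn.DataParallel(model).cuda()

batch = torch.randn(8, 768)
if torch.cuda.is_available():
    batch = batch.cuda()

out = model(batch)       # gathered back onto the default device
print(out.shape)         # torch.Size([8, 2])
```

Note that for multi-process training, `DistributedDataParallel` is generally preferred over `DataParallel`, but the batch-splitting idea is the same.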