Hi, is it possible for you to provide a snippet of your code, or some other way to reproduce the issue you are seeing? As in DataParallel imbalanced memory usage, it could be that the outputs of your forward pass are being gathered onto a single GPU (GPU 2 in your case), causing that GPU to run out of memory (OOM).