Out of memory error after certain epoch

I reviewed the posts linked to this topic where others encountered this error.

I eliminated all intermediate variables that might retain gradient graphs, called `item()` or `detach()` on variables once the backward pass was complete, and wrapped testing and validation in `torch.no_grad()` so that gradients are not retained during those phases.
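For reference, the pattern described above looks roughly like this minimal sketch (the model and data here are stand-ins, not the actual GAN code):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
criterion = nn.MSELoss()
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]

# Training: accumulate the loss as a Python float via .item(), not as a
# tensor, so the computation graph of each step can be freed.
running_loss = 0.0
for x, y in data:
    loss = criterion(model(x), y)
    loss.backward()
    model.zero_grad()
    running_loss += loss.item()  # plain float, detached from the graph

# Validation: under torch.no_grad() no graph is built at all.
with torch.no_grad():
    val_loss = sum(criterion(model(x), y).item() for x, y in data)

print(running_loss, val_loss)
```

Accumulating `loss` itself (instead of `loss.item()`) is a classic slow leak, because every stored tensor keeps its whole graph alive.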

I have a GAN model wrapped in DataParallel across five 24 GB GPUs. The Generator is a UNet, the Discriminator has 5 CNN layers, the adversarial loss is a Wasserstein loss with gradient penalty, and the Generator additionally uses an MSE loss. The input and reference images are both 512x512 pixels.
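A minimal sketch of that setup, with toy stand-in networks (the real UNet and 5-layer discriminator are assumptions, not the poster's code):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the UNet generator and CNN discriminator.
generator = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 3, 3, padding=1),
)
discriminator = nn.Sequential(
    nn.Conv2d(3, 8, 3, stride=2), nn.ReLU(),
    nn.Flatten(), nn.LazyLinear(1),
)

# DataParallel replicates each module across all visible GPUs and
# splits the batch along dim 0.
if torch.cuda.device_count() > 1:
    generator = nn.DataParallel(generator)
    discriminator = nn.DataParallel(discriminator)

x = torch.randn(2, 3, 64, 64)  # 512x512 in the real setup; smaller here
fake = generator(x)
score = discriminator(fake)
print(fake.shape, score.shape)
```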

When I check nvidia-smi, I have about 4 GB free on each GPU during training with a batch size of 16. There is even more free space during validation (around 8 GB on each).

However, after a certain number of epochs, say 30ish, I receive an out of memory error, even though the free GPU memory does not change significantly during training.

If I set the batch size to something unreasonably small, like 3 or 5, training works (at least for more than 100 epochs), but a single epoch then takes 20 to 30 minutes, which is unacceptable. I’d love it if you could share a solution to this problem so that others in the same situation can keep their training runs sustainable.

I’m running Ubuntu 20 along with torch 1.10.1+cu113.

It’s hard to speculate what might be causing the issue, as your steps sound reasonable, and based on your description (“I have 4 GB free on each GPU during training”) no OOM should be raised at all.
I would recommend trying to narrow down which operation is causing the OOM by checking the stacktrace and by adding debug print statements showing the current memory usage.
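A small helper for those debug prints could look like this (a sketch; place the calls wherever you suspect the growth, e.g. before the forward pass and after the backward pass):

```python
import torch

def gpu_mem_report(tag: str) -> str:
    """Format currently allocated and peak allocated memory per GPU."""
    if not torch.cuda.is_available():
        return f"[{tag}] CUDA not available"
    parts = []
    for i in range(torch.cuda.device_count()):
        alloc = torch.cuda.memory_allocated(i) / 1024**2
        peak = torch.cuda.max_memory_allocated(i) / 1024**2
        parts.append(f"cuda:{i} alloc={alloc:.0f}MiB peak={peak:.0f}MiB")
    return f"[{tag}] " + " | ".join(parts)

print(gpu_mem_report("before forward"))
# out = model(batch)   # hypothetical training step
print(gpu_mem_report("after backward"))
```

If the “alloc” number climbs epoch over epoch, something is holding references to tensors (and their graphs) across iterations.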

I figured out what the issue was.

I was calculating the MS-SSIM of the results using the TorchMetrics library. I realized that the stateful MS-SSIM metric class acts like a loss function and accumulates tensors that still hold gradient graphs, which caused the issue. I was able to fix the problem by replacing the metric class with its functional counterpart.
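The failure mode can be demonstrated without TorchMetrics itself; this toy stateful metric mimics how a metric class stores what it is updated with (the class and names here are illustrative, not the TorchMetrics API):

```python
import torch

class ToyMetric:
    """Stateful metric that keeps every value it is updated with,
    the way a class-based metric accumulates state across batches."""
    def __init__(self):
        self.values = []

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        self.values.append(((preds - target) ** 2).mean())

    def compute(self) -> torch.Tensor:
        return torch.stack(self.values).mean()

x = torch.randn(4, requires_grad=True)
y = torch.randn(4)

leaky = ToyMetric()
leaky.update(x * 2, y)               # stored value still references the graph
print(leaky.values[0].requires_grad)  # True -> whole graph pinned in memory

safe = ToyMetric()
safe.update((x * 2).detach(), y)     # detach first (or compute under no_grad)
print(safe.values[0].requires_grad)   # False -> graph can be freed
```

A one-shot functional call, or detaching inputs before `update()`, avoids accumulating graphs across batches, which matches the fix described above.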