torch.cuda.empty_cache() and RuntimeError: CUDA out of memory

I have a custom loss function. If the loss is positive, I call backward() as shown in the code below, and after validation I free memory using “torch.cuda.empty_cache()”.

The code runs fine as long as the loss is positive. When the loss is negative, the training part is skipped, as intended by the logic, but when training starts on the next batch it fails with “RuntimeError: CUDA out of memory”. While debugging, I found that the error disappears when I comment out torch.cuda.empty_cache() in the validation part. How does commenting out that line free up memory? Normally we add that line precisely to free up memory.

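```python
import torch

#### Training Part #####
# (inside a method of the model class; `loss` is the custom loss tensor)
if loss > 0:
    self.backward_G()
    self.optimizer_G.step()
torch.cuda.empty_cache()

#### Validation Part #####
with torch.no_grad():
    ...  # validation image forward pass (not shown in the original post)
torch.cuda.empty_cache()
```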

While it’s true that you are freeing the cache so that this memory can be used by other processes, you will not avoid OOM errors (internally empty_cache() will be called and the allocation will be re-executed), so calling empty_cache() manually will just slow down your code.
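To make the distinction concrete, here is a minimal sketch (assuming a CUDA device; it returns None otherwise, and the 64 MiB tensor size is an arbitrary choice). torch.cuda.memory_allocated() counts memory occupied by live tensors, while torch.cuda.memory_reserved() also includes blocks the caching allocator keeps for reuse. empty_cache() only hands the unoccupied cached blocks back to the driver; it cannot release memory that tensors still reference.

```python
import torch


def cache_demo(num_elements: int = 1 << 24):
    """Show what torch.cuda.empty_cache() does and does not release.

    Returns a dict of (allocated, reserved) byte counts at three stages,
    or None when no CUDA device is present.
    """
    if not torch.cuda.is_available():
        return None

    snapshots = {}
    x = torch.empty(num_elements, device="cuda")  # 2**24 float32 values = 64 MiB
    snapshots["live"] = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

    del x  # the block goes back to PyTorch's cache, not to the driver
    snapshots["deleted"] = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

    torch.cuda.empty_cache()  # unused cached blocks are now released to the driver
    snapshots["emptied"] = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
    return snapshots


if __name__ == "__main__":
    result = cache_demo()
    if result is None:
        print("CUDA not available; nothing to show")
    else:
        for stage, (alloc, reserved) in result.items():
            print(f"{stage:>8}: allocated={alloc / 2**20:.1f} MiB, "
                  f"reserved={reserved / 2**20:.1f} MiB")
```

If the "deleted" row still shows the memory as reserved, that is the cache at work: PyTorch will reuse it for the next allocation without calling empty_cache().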

Ok.

Then why does it show this strange behaviour? As long as the loss is non-negative, it works fine.

I don’t know and it’s hard to tell how the memory is scattered. You could add debug print statements via print(torch.cuda.memory_summary()) and check the allocations to see if anything stands out.
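A small helper along these lines can be called at a few points in the loop, for example right after the (possibly skipped) training branch and right before the next forward pass. This is a sketch assuming a single default CUDA device; the function name log_cuda_memory and the MiB formatting are illustrative choices, not part of PyTorch, while the torch.cuda calls themselves are standard.

```python
import torch


def log_cuda_memory(tag: str, full_summary: bool = False):
    """Print and return current CUDA memory counters for the default device.

    `tag` is a free-form label such as "after train step".
    Returns None on machines without CUDA.
    """
    if not torch.cuda.is_available():
        print(f"[{tag}] CUDA not available")
        return None

    stats = {
        "allocated_MiB": torch.cuda.memory_allocated() / 2**20,
        "reserved_MiB": torch.cuda.memory_reserved() / 2**20,
        "peak_allocated_MiB": torch.cuda.max_memory_allocated() / 2**20,
    }
    print(f"[{tag}] " + ", ".join(f"{k}={v:.1f}" for k, v in stats.items()))
    if full_summary:
        # Detailed per-pool table from the caching allocator
        print(torch.cuda.memory_summary(abbreviated=True))
    return stats
```

Comparing the "allocated" value after a step where the loss was negative with one where it was positive should show whether memory from the skipped step is still held.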

Ok. Thanks, I will look into it.

I just need to confirm whether self.backward_G() frees some memory when it executes. The reason is that when the loss is > 0, everything runs normally. Only when it is negative does the code skip this block, and then it throws the memory error in the next epoch.

self.backward_G() will compute the gradients and free the intermediate forward activations, which would reduce the memory usage. I assume you are not calling backward if the loss is negative and might thus keep the computation graph (with the intermediate activations) alive, which could raise the OOM error in the next iteration.
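One way to act on this, sketched under assumptions: the poster's model is not shown, so ToyModel, its attributes output and loss_G, and the mean-based placeholder loss are all hypothetical. The relevant part is the else branch: when backward is skipped, the tensors stored on self are detached so that nothing references the old graph any more and its saved activations can be freed before the next forward pass.

```python
import torch
import torch.nn as nn


class ToyModel:
    """Hypothetical stand-in for the poster's model: forward() stores tensors on self."""

    def __init__(self):
        self.net_G = nn.Linear(16, 16)
        self.optimizer_G = torch.optim.SGD(self.net_G.parameters(), lr=0.1)
        self.output = None
        self.loss_G = None

    def forward(self, x):
        self.output = self.net_G(x)
        # Placeholder for the real custom loss; it can be negative.
        self.loss_G = self.output.mean()

    def backward_G(self):
        # backward() also frees the graph's saved tensors (retain_graph=False by default)
        self.loss_G.backward()

    def train_step(self, x):
        self.optimizer_G.zero_grad()
        self.forward(x)
        if self.loss_G > 0:
            self.backward_G()
            self.optimizer_G.step()
        else:
            # No backward: drop the graph references held on self so the
            # intermediate activations are not kept alive into the next step.
            self.output = self.output.detach()
            self.loss_G = self.loss_G.detach()
        return self.loss_G.item()
```

Detaching (or setting the attributes to None) matters because attributes on a long-lived object are only overwritten after the next forward pass has already allocated its own activations, so without it two graphs can be alive at the same time. If the loss and outputs were plain local variables, they would be released when the function returns.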
