Speeding up cache clearing to prevent out-of-memory errors?

Because of how PyTorch’s caching memory allocator works, I get an out-of-memory error even though there is still unused memory on the GPU. I can work around this by calling torch.cuda.empty_cache(), but doing so roughly doubles the time my code takes to run.

    def closure():
        optimizer.zero_grad()

        # Run the network forward and compute the loss
        # Run loss.backward()

        torch.cuda.empty_cache()

        return loss

    optimizer.step(closure)

Is there another way to control PyTorch’s tendency to hold onto memory it isn’t using? Or can I somehow speed up torch.cuda.empty_cache()?

torch.cuda.empty_cache() shouldn’t give you any extra usable memory: it only releases the caching allocator’s unused blocks back to the driver, and those blocks would be reused for new allocations anyway. If calling it makes the OOM go away, that suggests a bug. Do you have a script we could use to reproduce the issue?
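
In the meantime, it would help to see how much memory is actually in use by tensors versus merely cached by the allocator around the point where you call empty_cache(). A minimal sketch of how you could log that (assuming a PyTorch version that has torch.cuda.memory_reserved(); older releases call it torch.cuda.memory_cached(), and the helper name here is just a placeholder):

    import torch

    def log_cuda_memory(tag):
        # memory_allocated(): bytes currently occupied by live tensors
        # memory_reserved(): bytes held by PyTorch's caching allocator
        # (torch.cuda.memory_cached() on older PyTorch versions)
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
        print(f"{tag}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")

    # Inside your closure, around the empty_cache() call:
    log_cuda_memory("before empty_cache")
    torch.cuda.empty_cache()
    log_cuda_memory("after empty_cache")

If "allocated" is close to "reserved" right before the OOM, the memory really is in use and empty_cache() shouldn't make a difference; if there is a large gap, posting those numbers (or the output of torch.cuda.memory_summary()) along with a repro script would help us narrow it down.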