Because of how PyTorch’s caching memory allocator works, I get an out-of-memory error even though there is still unused (cached) memory on the GPU. I can work around the issue by calling
torch.cuda.empty_cache(), but doing so roughly doubles the time it takes for my code to run.
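To illustrate the gap I mean (a minimal sketch, assuming a single CUDA device is available), torch.cuda.memory_allocated() stays well below torch.cuda.memory_reserved() when the error happens:

import torch

# Memory actually in use by live tensors (MiB)
allocated = torch.cuda.memory_allocated() / 1024**2
# Memory held by the caching allocator, including cached-but-unused blocks (MiB)
reserved = torch.cuda.memory_reserved() / 1024**2
print(f"allocated: {allocated:.1f} MiB, reserved: {reserved:.1f} MiB")

My current workaround looks roughly like this: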
def closure():
    optimizer.zero_grad()
    output = model(inputs)               # run network forward (model/inputs/criterion/targets are placeholders)
    loss = criterion(output, targets)    # collect loss
    loss.backward()                      # run loss backward
    torch.cuda.empty_cache()             # this call is what slows everything down
    return loss

optimizer.step(closure)
Is there another way I could control PyTorch’s tendency to hold onto memory that it’s not using? Or can I somehow speed up torch.cuda.empty_cache()?