Because of how PyTorch's caching memory allocator works, I get an out-of-memory error even though there is still unused memory on the GPU. I can work around this by calling torch.cuda.empty_cache(), but doing so roughly doubles my code's run time.
def closure():
    optimizer.zero_grad()
    # Run the network forward and collect the loss
    # Run loss backward
    torch.cuda.empty_cache()
    return loss

optimizer.step(closure)
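For reference, this is roughly how I'm checking that the memory is sitting in PyTorch's cache rather than being held by live tensors (just a minimal diagnostic snippet, run on the default CUDA device):

import torch

# Memory actually occupied by live tensors vs. memory reserved by the caching allocator.
# A large gap between the two is what makes me think the cache, not my tensors, is the problem.
allocated = torch.cuda.memory_allocated()  # bytes in use by tensors
reserved = torch.cuda.memory_reserved()    # bytes held by PyTorch's caching allocator
print(f"allocated: {allocated / 1e9:.2f} GB, reserved: {reserved / 1e9:.2f} GB")

# torch.cuda.memory_summary() prints a more detailed breakdown if needed.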
Is there another way I could control PyTorch's tendency to hold onto memory that it's not using? Or can I somehow speed up torch.cuda.empty_cache()?