In my case, I perform a lot of computation outside of the graph (under a torch.no_grad()
context), where many large tensors are computed or randomly generated at every forward pass.
I just inserted four torch.cuda.empty_cache()
calls throughout my ops at every forward pass, which resulted in a ~20% slowdown, but allowed me to increase my batch size from 9 to 14, which is a good trade-off for me. So far I haven’t run into any issues, and the model (a custom VGG-like network) trains fine.
I empty the cache immediately after I delete a tensor with the del
statement.
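For reference, here is a minimal sketch of the pattern I’m using (the function, tensor names, and sizes are made up for illustration; my actual ops are different):

```python
import torch

def heavy_side_computation(x: torch.Tensor) -> torch.Tensor:
    # All of this runs outside the autograd graph
    with torch.no_grad():
        # Hypothetical large intermediate generated every forward pass
        noise = torch.randn(4096, 4096, device=x.device)
        result = x @ noise

        # Drop the reference as soon as it's no longer needed,
        # then release the cached blocks held by the allocator
        del noise
        torch.cuda.empty_cache()

    return result
```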
Could you suggest any additional measures or alternatives to save memory?