I am developing an algorithm which is very memory-hungry, and I need to use `torch.utils.checkpoint` to fit it on my GPU. However, there are some training statistics which are easiest to compute alongside the computation I am checkpointing (doing it separately would be very wasteful because they share many of the same intermediate results). These statistics have no gradient and as such cannot be returned from a function wrapped in `checkpoint`. How can I work around that? Example code looks like the following:

```
import torch
from torch.utils.checkpoint import checkpoint

def to_be_checkpointed(t):
    intermediate_result_1 = t.pow(2)
    stats_1 = intermediate_result_1.detach().mean()
    intermediate_result_2 = intermediate_result_1.sqrt()
    stats_2 = intermediate_result_2.detach().std()
    output = intermediate_result_2.sum()
    return output, stats_1, stats_2

t = torch.randn(30, 30, requires_grad=True)
out, stats_1, stats_2 = checkpoint(to_be_checkpointed, t)
out.backward()
```

and on `out.backward()` I get the following error: `RuntimeError: element 1 of tensors does not require grad and does not have a grad_fn`.

Obviously, one way around this is to not detach the statistics inside the checkpointed function and only detach them after the `checkpoint` call (see the sketch after this list), but:

- It is not always possible (some statistics are simply non-differentiable).
- It's wasteful, I think (will the gradient still be calculated just to be discarded?).
- It's non-idiomatic and error-prone (I'm writing this post after spending several hours hunting down a bug I accidentally introduced this way).
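
For concreteness, this is roughly the workaround I mean, adapted from the example above. It is only a minimal sketch: the statistics are returned with their autograd state intact and detached after the `checkpoint` call, and it only happens to run here because `mean` and `std` are differentiable, which is not the case for my real statistics.

```
import torch
from torch.utils.checkpoint import checkpoint

def to_be_checkpointed(t):
    intermediate_result_1 = t.pow(2)
    # no .detach() here, so every returned tensor still has a grad_fn
    stats_1 = intermediate_result_1.mean()
    intermediate_result_2 = intermediate_result_1.sqrt()
    stats_2 = intermediate_result_2.std()
    output = intermediate_result_2.sum()
    return output, stats_1, stats_2

t = torch.randn(30, 30, requires_grad=True)
out, stats_1, stats_2 = checkpoint(to_be_checkpointed, t)

# detach only after the checkpoint call, once the outputs have satisfied
# checkpoint's requirement of carrying grad_fns
stats_1 = stats_1.detach()
stats_2 = stats_2.detach()

out.backward()
```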