Return variables with no grad from a checkpointed function

I am developing an algorithm that is very memory hungry, and I need to use torch.utils.checkpoint to fit it on my GPU. However, some training statistics are easiest to compute alongside the checkpointed computation (computing them separately would be very wasteful, because they share many of the same intermediate results). These statistics have no gradient and as such cannot be returned from a function wrapped in checkpoint. How can I work around that? Example code looks like the following:

import torch
from torch.utils.checkpoint import checkpoint

def to_be_checkpointed(t):
    intermediate_result_1 = t.pow(2)
    stats_1 = intermediate_result_1.detach().mean()  # detached: carries no grad

    intermediate_result_2 = intermediate_result_1.sqrt()
    stats_2 = intermediate_result_2.detach().std()  # detached: carries no grad

    output = intermediate_result_2.sum()

    return output, stats_1, stats_2

t = torch.randn(30, 30, requires_grad=True)
out, stats_1, stats_2 = checkpoint(to_be_checkpointed, t)
out.backward()

On calling out.backward() I get the following error:

RuntimeError: element 1 of tensors does not require grad and does not have a grad_fn

Obviously, one way around this is to not detach the statistics inside the function and only detach them after the checkpoint call (sketched after the list below), but:

  1. It is not always possible (some statistics are straight-up non-differentiable)
  2. It’s wasteful, I think (will the gradient still be calculated just to be discarded?)
  3. It’s non-idiomatic and error-prone (I’m writing this post after spending several hours hunting down a bug I accidentally introduced this way)
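
For reference, this is roughly what I mean by that workaround (untested sketch, with a hypothetical to_be_checkpointed_no_detach variant of the function above):

def to_be_checkpointed_no_detach(t):
    intermediate_result_1 = t.pow(2)
    stats_1 = intermediate_result_1.mean()  # no .detach() here

    intermediate_result_2 = intermediate_result_1.sqrt()
    stats_2 = intermediate_result_2.std()  # no .detach() here

    output = intermediate_result_2.sum()
    return output, stats_1, stats_2

out, stats_1, stats_2 = checkpoint(to_be_checkpointed_no_detach, t)
stats_1, stats_2 = stats_1.detach(), stats_2.detach()  # detach only after the call
out.backward()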

Maybe manually implementing a specialized function would work well for you.
The implementation of checkpointing is a bit tricky, but grabbing CheckpointFunction and amending it to support your stats should not be too hard.
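
Something like this untested sketch could be a starting point. It is a stripped-down variant of the idea (it does not save/restore the RNG state the way the real CheckpointFunction does, and it assumes every extra argument is a tensor): it marks the statistics as non-differentiable and only backpropagates through the first output.

import torch

class CheckpointWithStats(torch.autograd.Function):
    # Sketch only: no RNG-state handling, no keyword arguments,
    # assumes run_function returns (differentiable_output, *no_grad_stats).

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        ctx.save_for_backward(*args)
        with torch.no_grad():
            output, *stats = run_function(*args)
        # The statistics never need gradients, so tell autograd that.
        ctx.mark_non_differentiable(*stats)
        return (output, *stats)

    @staticmethod
    def backward(ctx, grad_output, *unused_stat_grads):
        # Recompute the forward pass with grad enabled, then backprop
        # only through the differentiable output.
        inputs = [t.detach().requires_grad_(t.requires_grad)
                  for t in ctx.saved_tensors]
        with torch.enable_grad():
            output, *_ = ctx.run_function(*inputs)
        torch.autograd.backward(output, grad_output)
        # No gradient for run_function itself, then one per input tensor.
        return (None,) + tuple(t.grad for t in inputs)

def checkpoint_with_stats(run_function, *args):
    return CheckpointWithStats.apply(run_function, *args)

With that, out, stats_1, stats_2 = checkpoint_with_stats(to_be_checkpointed, t) followed by out.backward() should work with your function as written, detach calls and all.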

Best regards

Thomas

Cool, this looks promising.

Thank you,
Michał