I find myself needing to manually implement gradient checkpointing for several custom autograd functions. This involves "merging" a few autograd.Function instances into a single new autograd.Function. How can I achieve this? I would like to reuse the forward and backward of each nested autograd.Function, and each of them may save different data for backward.
One way of doing it is to just factor the reusable logic out into helper functions that both the individual Functions and the merged one can call, as sketched below.
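For example, a minimal sketch of that pattern (the mul/exp helpers here are made-up placeholders; each returns its output plus whatever its backward needs, so a standalone Function and a merged one can both call them):

import torch

def mul_fwd(a, b):
    # Output plus the tensors the backward will need
    return a * b, (a, b)

def mul_bwd(grad_out, saved):
    a, b = saved
    return grad_out * b, grad_out * a

def exp_fwd(x):
    y = torch.exp(x)
    return y, (y,)

def exp_bwd(grad_out, saved):
    (y,) = saved
    return grad_out * y

class MulThenExp(torch.autograd.Function):
    # Merged Function: runs both helpers in forward and saves everything on one ctx

    @staticmethod
    def forward(ctx, a, b):
        y, saved_mul = mul_fwd(a, b)
        z, saved_exp = exp_fwd(y)
        ctx.save_for_backward(*saved_mul, *saved_exp)
        ctx.n_mul = len(saved_mul)  # remember where the two groups split
        return z

    @staticmethod
    def backward(ctx, grad_out):
        saved = ctx.saved_tensors
        saved_mul, saved_exp = saved[:ctx.n_mul], saved[ctx.n_mul:]
        # Unwind the chain in reverse order of the forward calls
        grad_y = exp_bwd(grad_out, saved_exp)
        return mul_bwd(grad_y, saved_mul)

You can sanity-check a merged Function like this with torch.autograd.gradcheck(MulThenExp.apply, (a, b)) on double-precision inputs that require grad.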
Have you looked at torch.utils.checkpoint (see the PyTorch docs), by the way? Curious why that doesn't work for your case.
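For reference, a minimal usage sketch (the block and shapes are made up): the wrapped callable's intermediate activations are recomputed during backward instead of being stored.

import torch
from torch.utils.checkpoint import checkpoint

def block(x, w1, w2):
    # Intermediates inside this callable are recomputed during backward
    return torch.relu(x @ w1) @ w2

x = torch.randn(8, 16, requires_grad=True)
w1 = torch.randn(16, 32, requires_grad=True)
w2 = torch.randn(32, 16, requires_grad=True)

out = checkpoint(block, x, w1, w2, use_reentrant=False)
out.sum().backward()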
Thanks. My use case involves certain constraints that make it easier to do this manually. I'm not entirely sure whether checkpoint is the right tool here, but a similar codebase does it manually, and I'm following their approach.
Would something like this work?
import torch

class CustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        # Create a separate context object for each nested Function
        ctx_hack1 = torch.autograd.function.FunctionCtx()
        ctx_hack2 = torch.autograd.function.FunctionCtx()
        ctx_hack3 = torch.autograd.function.FunctionCtx()
        # Some operations...
        out = SomeOtherFunction.forward(ctx_hack1, *args)
        out = AnotherFunction.forward(ctx_hack2, out, *args)
        # Context hacks: each nested forward stashed its tensors on its own
        # context via save_for_backward (i.e. in `to_save`). Re-save them all
        # through the real ctx and remember where to split them in backward.
        ctx.n_saved1 = len(ctx_hack1.to_save)
        ctx.save_for_backward(*ctx_hack1.to_save, *ctx_hack2.to_save)
        ctx.sub_contexts = (ctx_hack1, ctx_hack2, ctx_hack3)
        ctx_hack1.to_save = None
        ctx_hack2.to_save = None
        return out

    @staticmethod
    def backward(ctx, *grad_outputs):
        # ...
That should work, yes
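The matching backward then just hands the saved tensors back to the sub-contexts before calling the nested backwards. A rough sketch only, assuming the forward recorded the split point (ctx.n_saved1 above) and relying on the fact that a plain FunctionCtx lets you assign saved_tensors as an ordinary attribute:

    @staticmethod
    def backward(ctx, *grad_outputs):
        ctx_hack1, ctx_hack2, ctx_hack3 = ctx.sub_contexts  # ctx_hack3 unused, mirroring the forward
        saved = ctx.saved_tensors

        # Give each sub-context its slice of the saved tensors so the nested
        # backward implementations can read ctx.saved_tensors as usual
        ctx_hack1.saved_tensors = saved[:ctx.n_saved1]
        ctx_hack2.saved_tensors = saved[ctx.n_saved1:]

        # Unwind the chain in reverse order of the forward calls
        grads = AnotherFunction.backward(ctx_hack2, *grad_outputs)
        grads = SomeOtherFunction.backward(ctx_hack1, grads[0])
        # If AnotherFunction also consumed *args directly, its gradients for
        # those inputs would need to be accumulated into `grads` here as well
        return grads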
Thank you, closing this!