Calling autograd.Function in autograd.Function

I find myself needing to manually implement gradient checkpointing for several custom autograd functions. This involves manually "merging" a few autograd.Function subclasses into a new autograd.Function. How can I achieve this? I would like to reuse the forward and backward of each nested autograd.Function, and each of them might save different data for its backward pass.

One way of doing it is to just factor out the reusable logic into helper functions that both the individual autograd.Functions and the merged one can call.
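
For example, here is a minimal sketch with made-up helpers (op1_fwd/op1_bwd and op2_fwd/op2_bwd are placeholders, not anything from your code): the standalone autograd.Functions and the merged one would all call the same helper pair, and each helper can save whatever its backward needs.

import torch


# Hypothetical helpers holding the reusable math.
def op1_fwd(x):
    return x ** 3, (x,)              # output, tensors needed for backward

def op1_bwd(saved, grad_out):
    (x,) = saved
    return grad_out * 3 * x ** 2

def op2_fwd(y):
    return y * 2, ()                 # this one saves nothing

def op2_bwd(saved, grad_out):
    return grad_out * 2


class MergedFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        mid, saved1 = op1_fwd(x)
        out, saved2 = op2_fwd(mid)
        ctx.save_for_backward(*saved1, *saved2)
        ctx.n1 = len(saved1)         # how many saved tensors belong to op1
        return out

    @staticmethod
    def backward(ctx, grad_out):
        saved = ctx.saved_tensors
        grad_mid = op2_bwd(saved[ctx.n1:], grad_out)
        return op1_bwd(saved[:ctx.n1], grad_mid)


# Usage: gradients flow through the single merged node.
x = torch.randn(3, requires_grad=True)
MergedFunction.apply(x).sum().backward()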

Have you looked at torch.utils.checkpoint, by the way? Curious why that doesn't work for your case.
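
For reference, the non-manual route just wraps the recomputed region in torch.utils.checkpoint.checkpoint. A minimal usage sketch (block is a made-up callable; use_reentrant=False is the recommended mode in recent releases):

import torch
from torch.utils.checkpoint import checkpoint


def block(x):
    # Any differentiable callable; its intermediate activations are recomputed
    # during backward instead of being stored.
    return torch.relu(x @ x.t())


x = torch.randn(8, 8, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()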

Thanks. My use case involves certain things that make it easier to do it manually. I'm not entirely sure checkpoint is the right tool here, but a similar codebase does it manually, and I'm following their approach.

Would something like this work?

import torch


class CustomFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, *args, **kwargs):
        # Standalone context objects for the nested functions. Calling their
        # forward methods directly (rather than via .apply) reuses their code
        # without registering them as separate autograd nodes.
        ctx_hack1 = torch.autograd.function.FunctionCtx()
        ctx_hack2 = torch.autograd.function.FunctionCtx()
        ctx_hack3 = torch.autograd.function.FunctionCtx()

        # Some operations...
        out = SomeOtherFunction.forward(ctx_hack1, *args)
        out = AnotherFunction.forward(ctx_hack2, out, *args)

        # Context hacks: save_for_backward inside the nested forwards stashed
        # their tensors on ctx_hackN.to_save, so re-save everything on the
        # real ctx and clear the temporary lists to avoid holding duplicate
        # references.
        ctx.save_for_backward(*ctx_hack1.to_save, *ctx_hack2.to_save)
        ctx.sub_contexts = (ctx_hack1, ctx_hack2, ctx_hack3)
        ctx_hack1.to_save = None
        ctx_hack2.to_save = None

        return out

    @staticmethod
    def backward(ctx, *grad_outputs):  # backward only receives grad outputs
        # ...
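
For the backward, I'm imagining something along these lines. This is only a rough sketch: it assumes a single tensor output, that forward additionally records ctx.n1 = len(ctx_hack1.to_save) before clearing it, that both nested backwards read their tensors from ctx.saved_tensors, and that assigning saved_tensors as a plain attribute on a bare FunctionCtx is acceptable to them.

    @staticmethod
    def backward(ctx, grad_out):
        ctx_hack1, ctx_hack2, _ = ctx.sub_contexts
        saved = ctx.saved_tensors

        # Hand each nested backward the tensors its forward saved.
        # ctx.n1 is assumed to have been recorded in forward.
        ctx_hack1.saved_tensors = saved[:ctx.n1]
        ctx_hack2.saved_tensors = saved[ctx.n1:]

        # Walk the chain in reverse: AnotherFunction consumed (out, *args),
        # SomeOtherFunction consumed (*args,).
        grad_mid, *grads_from_2 = AnotherFunction.backward(ctx_hack2, grad_out)
        grads_from_1 = SomeOtherFunction.backward(ctx_hack1, grad_mid)
        if not isinstance(grads_from_1, tuple):
            grads_from_1 = (grads_from_1,)

        # Both nested functions saw *args, so accumulate their contributions
        # per input; how this maps onto your real signatures will differ.
        return tuple(
            g1 if g2 is None else g2 if g1 is None else g1 + g2
            for g1, g2 in zip(grads_from_1, grads_from_2)
        )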

That should work, yes

Thank you, closing this!