Is it possible to access the ctx object in the forward function in 2.9?

update at bottom.

It was possible in 2.6 a year ago. I updated to 2.9 yesterday, and now the ctx object is not passed to the forward function at all.

The question is: if I want to store anything for the backward pass, I have to override the setup_context function, and some docs said I should save the data on ctx there, but the ctx object is not accessible in the forward function. Did I miss any detail? Or how do you implement this?

Thank you.

update:

    from typing import Any

    import torch

    class Separated(torch.autograd.Function):
        @staticmethod
        def forward(input: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
            # This is the torch 2.x pattern; PyTorch recommends doing it this way.
            # There is no access to the ctx object in this function in the separated pattern.
            return input

        @staticmethod
        def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
            # If you need to do anything with ctx, do it here.
            pass

        @staticmethod
        def backward(ctx: Any, *grads):
            # You always get ctx in the backward function.
            return None

    class Combined(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
            # This is the torch 1.x pattern; you get ctx as the first argument.
            return input

        @staticmethod
        def backward(ctx: Any, *grads):
            # You always get ctx in the backward function.
            return None

Ok, I believe this is the answer. Thank you anyway.
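To make the separated pattern concrete, here is a minimal worked sketch (my own example, not from the docs) of saving a tensor for backward: forward gets no ctx, setup_context receives the forward inputs and output and calls ctx.save_for_backward, and backward reads ctx.saved_tensors.

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(input):
        # Separated (torch 2.x) pattern: no ctx parameter here.
        return input ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple of arguments that forward received.
        (input,) = inputs
        ctx.save_for_backward(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve what setup_context stashed on ctx.
        (input,) = ctx.saved_tensors
        # d(x^2)/dx = 2x
        return grad_output * 2 * input

x = torch.tensor([3.0], requires_grad=True)
y = Square.apply(x)
y.backward()
print(x.grad)  # -> tensor([6.])
```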

Hi @YagaoDirac, could you provide a code example of what you mean here, please?

The example code for torch.autograd.Function that takes ctx in forward still runs.

Also see this, which has not changed since 2.6.


Thank you for the clue. I figured it out.