A year ago, on PyTorch 2.6, I could use ctx inside forward. I updated to 2.9 yesterday, and now the ctx object is not passed to forward at all.
My question is: if I want to store anything for the backward pass, I have to override setup_context, and some docs say I should save the data on ctx so it reaches setup_context. But ctx is not accessible in forward. Did I miss a detail? How do you implement this?
Thank you.
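A minimal sketch of storing data for backward with the separate pattern, assuming torch >= 2.0 (`Cube` is a hypothetical example class, not from the post): setup_context receives the tuple of forward's inputs and its output, so anything derived from them can be saved on ctx there instead of in forward.

```python
import torch


class Cube(torch.autograd.Function):
    """Hypothetical example: y = x**3 using the separate forward/setup_context pattern."""

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        # No ctx here: forward only computes the result.
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple of arguments forward received; output is its return value.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        # d(x**3)/dx = 3 * x**2
        return grad_output * 3 * x ** 2


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = Cube.apply(x)
y.sum().backward()
# x.grad holds 3 * x**2, i.e. the values 3, 12, 27
```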
Update: the two patterns side by side.

```python
from typing import Any

import torch


class Separated(torch.autograd.Function):
    @staticmethod
    def forward(input: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
        # The torch 2.x pattern; PyTorch's docs require it for torch.func transform support.
        # When setup_context is defined, forward does not receive ctx.
        return input

    @staticmethod
    def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
        # If you need to do anything with ctx (e.g. ctx.save_for_backward), do it here.
        pass

    @staticmethod
    def backward(ctx: Any, *grads):
        # backward always receives ctx.
        # Return one gradient (or None) per forward input.
        return None


class Combined(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: Any, *args: Any, **kwargs: Any) -> torch.Tensor:
        # The torch 1.x pattern, still supported in 2.x: ctx is the first argument.
        return input

    @staticmethod
    def backward(ctx: Any, *grads):
        # backward always receives ctx.
        # Return one gradient (or None) per forward input.
        return None
```
OK, I believe this is the answer: ctx is passed to forward only when setup_context is not defined. Thank you anyway.
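For a value computed inside forward that backward also needs (not just an input or the output), PyTorch's torch.func extension docs describe returning it as an extra output so setup_context can see it. A minimal sketch under that assumption, with torch >= 2.0; `SoftplusWithIntermediate` and `softplus` are hypothetical names:

```python
import torch


class SoftplusWithIntermediate(torch.autograd.Function):
    """Hypothetical example: y = log(1 + exp(x)), reusing exp(x) in backward."""

    @staticmethod
    def forward(x: torch.Tensor):
        e = torch.exp(x)  # intermediate that backward also needs
        y = torch.log1p(e)
        # Return the intermediate as an extra output so setup_context receives it.
        return y, e

    @staticmethod
    def setup_context(ctx, inputs, output):
        y, e = output
        # e is only a helper value; no gradient should flow through it.
        ctx.mark_non_differentiable(e)
        ctx.save_for_backward(e)

    @staticmethod
    def backward(ctx, grad_y, grad_e):
        # grad_e belongs to the non-differentiable output and is ignored.
        (e,) = ctx.saved_tensors
        # d/dx log(1 + e^x) = e^x / (1 + e^x)
        return grad_y * e / (1 + e)


def softplus(x: torch.Tensor) -> torch.Tensor:
    # Hide the extra output from callers.
    y, _ = SoftplusWithIntermediate.apply(x)
    return y


x = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
softplus(x).sum().backward()
```

The alternative is the combined pattern, where forward gets ctx and can call ctx.save_for_backward on the intermediate directly.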