I defined a custom autograd function that has grown very long. I want to split the long computation into a few helper functions so that it is more readable, and ideally so that the gradient/jacobian of each helper function can be used directly.

The original implementation:

```
class customFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, data, option):
        # some long computation
        res = ...
        # compute the gradient directly
        grad_input = ...
        ctx.save_for_backward(grad_input)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, = ctx.saved_tensors
        return grad_output * grad_input, None, None
```

The converted code looks like this:

```
from torch.autograd.functional import jacobian

class customFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, data, option):
        # save data and option to be used in helper functions
        ctx.data = data
        ctx.option = option
        # get result with a helper function
        res1 = ctx.get_res1(x)
        res = someRegularFunc(res1)
        # compute the gradient by using the jacobian of `get_res1()`
        grad_res1_x = jacobian(ctx.get_res1, x)
        grad_input = ...
        ctx.save_for_backward(grad_input)
        return res

    def get_res1(ctx, x):
        data, option = ctx.data, ctx.option
        res1 = ...  # some short computation
        return res1

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, = ctx.saved_tensors
        return grad_output * grad_input, None, None
```

However, it does not work. Inside a `@staticmethod`, `ctx` is not an instance of `customFunc`, so the helper method defined on the class cannot be reached through it:

```
AttributeError: 'customFuncBackward' object has no attribute 'get_res1'
```
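For what it's worth, moving the helper out to module level and passing `data` explicitly does avoid the `AttributeError` in a minimal sketch like the one below (the computation in `get_res1` is just a placeholder, not my real code), but I would still like to know whether the helper can stay inside the `Function` class:

```python
import torch
from torch.autograd.functional import jacobian

# Module-level helper: no ctx needed, extra arguments passed explicitly.
# The body is a trivial stand-in for the real short computation.
def get_res1(x, data):
    return x * data

class CustomFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, data):
        res1 = get_res1(x, data)
        res = res1.sum()
        # Differentiate w.r.t. x only by closing over data in a lambda.
        grad_res1_x = jacobian(lambda x_: get_res1(x_, data), x)
        # Chain rule for res = sum(res1): sum the jacobian rows.
        grad_input = grad_res1_x.sum(dim=0)
        ctx.save_for_backward(grad_input)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, = ctx.saved_tensors
        return grad_output * grad_input, None

x = torch.randn(3, requires_grad=True)
data = torch.tensor([1.0, 2.0, 3.0])
y = CustomFunc.apply(x, data)
y.backward()
# gradient of sum(x * data) w.r.t. x is data
print(torch.allclose(x.grad, data))  # prints True
```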

Is there a way to solve this problem?