Custom autograd function with helper functions

I defined a custom autograd function, and it is very long. I want to split the long computation into a few helper functions to make it more readable. Could the gradient/Jacobian of those helper functions then be used directly?

The original implementation

import torch

class customFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, data, option):
        # some long computation
        res = ...
        # compute the gradient directly
        grad_input = ...     
        ctx.save_for_backward(grad_input)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, = ctx.saved_tensors
        return grad_output * grad_input, None, None
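A minimal runnable sketch of this "compute the gradient in forward" pattern, assuming a hypothetical elementwise computation `res = data * torch.sin(x)` in place of the long one. Because the operation is elementwise, the saved derivative can simply be multiplied elementwise with `grad_output` in `backward`; the class name `SinScale` is illustrative only.

```python
import torch

class SinScale(torch.autograd.Function):
    """Hypothetical example: res = data * sin(x), computed elementwise."""

    @staticmethod
    def forward(ctx, x, data, option):
        # `option` is unused here; it only mirrors the original signature.
        res = data * torch.sin(x)
        # Elementwise derivative d(res)/d(x), computed eagerly in forward.
        grad_input = data * torch.cos(x)
        ctx.save_for_backward(grad_input)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, = ctx.saved_tensors
        # Elementwise chain rule; no gradients for `data` and `option`.
        return grad_output * grad_input, None, None

x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
data = torch.tensor([1.0, 2.0, 3.0])
SinScale.apply(x, data, None).sum().backward()
print(x.grad)  # should equal data * cos(x)
```

The elementwise product in `backward` is only valid because each output element depends on a single input element; a general (non-diagonal) Jacobian needs a vector-Jacobian product instead.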

The converted code looks like this:

import torch
from torch.autograd.functional import jacobian

class customFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, data, option):
        # save data and option to be used in helper functions
        ctx.data = data
        ctx.option = option

        # get result with a helper function
        res1 = ctx.get_res1(x)
        res  = someRegularFunc(res1)

        # compute the gradient by using the jacobian of `get_res1()`
        grad_res1_x = jacobian(get_res1, x)
        grad_input = ...
        ctx.save_for_backward(grad_input)
        return res

    def get_res1(ctx, x):
        data, option = ctx.data, ctx.option
        res1 = ...      # some short computation
        return res1  

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, = ctx.saved_tensors
        return grad_output * grad_input, None, None

However, it does not work, since a @staticmethod cannot access the properties of the class itself:

AttributeError: 'customFuncBackward' object has no attribute 'get_res1'

Is there a way to solve this problem?

Wouldn’t it work to define get_res1 itself as a staticmethod or outside of the custom autograd.Function?

Thanks for your quick reply.

  1. Defining get_res1 as a staticmethod inside the custom autograd.Function does not work.

(I am not familiar with Python. Please correct me if I am wrong.)
Reason: a staticmethod can’t access or modify the class state, so if get_res1 is defined inside the class, forward cannot use it.
Reference: Class method vs Static method in Python - GeeksforGeeks.

  2. Defining get_res1 as a regular function outside the custom autograd.Function works.

Here is an example that computes y = 2(Ax + b), where A = data and b = option:

import torch
from torch.autograd.functional import jacobian

def get_res1(x, data, option):
    res1 = data@x + option     # some short computation
    return res1  

class customFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, data, option):
        # get result with a helper function
        res1 = get_res1(x, data, option)
        res  = 2 * res1

        # compute the gradient by using the jacobian of `get_res1()`
        grad_res1_x = jacobian(get_res1, (x, data, option))[0]
        grad_input = 2 * grad_res1_x
        ctx.save_for_backward(grad_input)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, = ctx.saved_tensors
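        # grad_input is the full Jacobian d(res)/d(x), so use a vector-Jacobian product;
        # an elementwise product only happens to give the right answer when grad_output is all ones
        return grad_output @ grad_input, None, None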

Test

data = torch.tensor([[1., 2], [3, 4]])
option = torch.tensor([10., 10.])
x = torch.tensor([1., 1.], requires_grad=True)
s = customFunc.apply(x, data, option).sum()   # s = ((data@x + option)*2).sum()

s.backward()
x.grad           # -> tensor([ 8., 12.])
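For this example the Jacobian is dy/dx = 2A, so for an upstream gradient g the backward pass should return the vector-Jacobian product 2Aᵀg. With `.sum()`, g is all ones, which can hide the difference between an elementwise product and a true vector-Jacobian product. A sketch that checks the backward against finite differences with `torch.autograd.gradcheck` (double precision, as gradcheck expects), assuming `grad_output @ grad_input` in `backward` and a lambda so that `jacobian` differentiates with respect to `x` only:

```python
import torch
from torch.autograd.functional import jacobian

def get_res1(x, data, option):
    return data @ x + option

class customFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, data, option):
        res = 2 * get_res1(x, data, option)
        # Full Jacobian d(res)/d(x) of shape (m, n); data and option are held fixed.
        grad_input = 2 * jacobian(lambda x_: get_res1(x_, data, option), x)
        ctx.save_for_backward(grad_input)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, = ctx.saved_tensors
        # Vector-Jacobian product: (m,) @ (m, n) -> (n,)
        return grad_output @ grad_input, None, None

data = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.double)
option = torch.tensor([10., 10.], dtype=torch.double)
x = torch.tensor([1., 1.], dtype=torch.double, requires_grad=True)

# gradcheck compares the analytical backward with numerical finite differences
# and returns True on success (it raises an error on mismatch).
ok = torch.autograd.gradcheck(lambda x_: customFunc.apply(x_, data, option), (x,))
print(ok)
```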

However, it would be better if get_res1 were bound to the custom autograd.Function, since only the custom autograd.Function will use it.

Also, if it is defined outside as a regular function, then many inputs need to be passed to get_res1() manually, whereas if it were a class method, one could use ctx.xxx to access any properties/variables of the class.

My original code (without any helper functions) works. I am just looking for a “better” way to implement it, that is, to use helper functions inside the class to improve readability (and to use jacobian directly in some cases). However, if the helper functions need to be defined outside, that is also fine.

I’m unsure which class state you are referring to, since you would use the ctx object, wouldn’t you?

import torch

class customFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, data, option):
        # save data and option to be used in helper functions
        ctx.data = data
        ctx.option = option

        # get result with a helper function
        res = customFunc.get_res1(ctx, x)
       
        return res

    @staticmethod
    def get_res1(ctx, x):
        data, option = ctx.data, ctx.option
        print(data, option)
        res1 = data
        return res1  

    @staticmethod
    def backward(ctx, grad_output):
        #grad_input, = ctx.saved_tensors
        return grad_output, None, None
    
fun = customFunc.apply
fun(torch.randn(1, 1), torch.rand(1, 1), torch.randn(1, 1))
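Applied to the y = 2(Ax + b) example from earlier in the thread, the same idea keeps the helper inside the class as a staticmethod and calls it through the class name. In this sketch the tensors are passed to the helper as arguments rather than stored on ctx, and a lambda fixes `data` and `option` so that `jacobian` differentiates with respect to `x` only; the class name `AffineTimesTwo` is hypothetical.

```python
import torch
from torch.autograd.functional import jacobian

class AffineTimesTwo(torch.autograd.Function):
    """Hypothetical name; computes y = 2 * (data @ x + option)."""

    @staticmethod
    def get_res1(x, data, option):
        # Helper kept inside the class; it needs neither `self` nor `ctx`.
        return data @ x + option

    @staticmethod
    def forward(ctx, x, data, option):
        # Static helpers are reached through the class name, not through ctx.
        res1 = AffineTimesTwo.get_res1(x, data, option)
        res = 2 * res1
        # Freeze data/option so jacobian differentiates w.r.t. x only.
        jac_res1_x = jacobian(lambda x_: AffineTimesTwo.get_res1(x_, data, option), x)
        ctx.save_for_backward(2 * jac_res1_x)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        jac, = ctx.saved_tensors
        # Vector-Jacobian product for x; data and option get no gradient.
        return grad_output @ jac, None, None

data = torch.tensor([[1., 2.], [3., 4.]])
option = torch.tensor([10., 10.])
x = torch.tensor([1., 1.], requires_grad=True)
AffineTimesTwo.apply(x, data, option).sum().backward()
print(x.grad)  # tensor([ 8., 12.])
```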

Now I see why I got the error. In my forward, I call the helper function as ctx.get_res1(), which raises the error:

AttributeError: 'customFuncBackward' object has no attribute 'get_res1'

Whereas in your forward, you use customFunc.get_res1().

I thought ctx was something like self, so that ctx.get_res1() would work the way self.some_method() does.

Thanks for your help.
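One way to see why ctx does not behave like self is to inspect it inside forward. As the error message above suggests, ctx is an instance of an automatically generated backward-node class named after the Function (customFunc becomes customFuncBackward), not an instance of the Function subclass, so the Function's staticmethods are not attributes of ctx. A minimal sketch with a hypothetical class `Probe` that records these facts in a module-level dict:

```python
import torch

seen = {}  # filled in by Probe.forward for inspection afterwards

class Probe(torch.autograd.Function):
    @staticmethod
    def helper():
        return "helper called"

    @staticmethod
    def forward(ctx, x):
        # Record what ctx actually is.
        seen["type_name"] = type(ctx).__name__
        seen["is_probe"] = isinstance(ctx, Probe)
        seen["has_helper"] = hasattr(ctx, "helper")
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.ones(2, requires_grad=True)
y = Probe.apply(x)
print(seen)            # type name, isinstance check, helper lookup on ctx
print(Probe.helper())  # the staticmethod is reachable through the class name
```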

So is it true to say that ctx plays a different role than self does in a regular class?

If so, is it fine to define or save some variables/attributes on ctx and use them later in other methods? That is, is it safe/correct to do the following:

ctx.a = a   # in `forward` method
a = ctx.a   # in some other method

Yes, you can store arbitrary objects on the ctx object, but you need to be careful with tensors.
To store tensors, you should use ctx.save_for_backward, as you could otherwise create e.g. memory leaks.
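A minimal sketch of that split, using a hypothetical `ScaledPow` Function (y = scale · x^p): the Python numbers `p` and `scale` are stored as plain attributes on ctx, while the input tensor goes through `ctx.save_for_backward` and is read back from `ctx.saved_tensors`.

```python
import torch

class ScaledPow(torch.autograd.Function):
    """Hypothetical example: y = scale * x ** p, with p and scale Python floats."""

    @staticmethod
    def forward(ctx, x, p, scale):
        # Non-tensor values: plain attributes on ctx are fine.
        ctx.p = p
        ctx.scale = scale
        # Tensors needed in backward: use save_for_backward.
        ctx.save_for_backward(x)
        return scale * x ** p

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad_x = grad_output * ctx.scale * ctx.p * x ** (ctx.p - 1)
        # One return value per forward input; non-tensor inputs get None.
        return grad_x, None, None

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
ScaledPow.apply(x, 3.0, 0.5).sum().backward()
print(x.grad)  # should equal 0.5 * 3 * x**2 = 1.5 * x**2
```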