Defining an object inside autograd.Function

Hi, I’m trying to define a helper method inside my autograd.Function, but when I call it inside backward, I get an error saying it is not defined…


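from torch import autograd

class PermutationLayer(autograd.Function):
    @staticmethod
    def forward(ctx, x, x1, x2):
        ctx.save_for_backward(x, x1, x2)
        ...  # output computation omitted in the original post

    @staticmethod
    def backward(ctx, *grad_outputs):  # one incoming gradient per forward output
        x, x1, x2 = ctx.saved_tensors

        grad_x1 = want_to_add(x1)  # NameError: name 'want_to_add' is not defined

        # grad_x and grad_x2 are computed elsewhere (omitted in the original post)
        return grad_x, grad_x1, grad_x2  # one gradient per forward input; none for ctx

    def want_to_add(ctx, x_in):
        ...  # body omitted in the original post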

Could someone explain why this happens, and how to define a helper function that can be used inside the autograd.Function?

Thank you for reading!


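You may want to double-check how Python classes work: a function defined in a class body is not visible as a bare name inside the other methods of that class; it has to be looked up through the class (or an instance).
And because forward and backward are static methods, we never create instances of this class (it is used through PermutationLayer.apply), so there is no self to call the helper on.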
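The scoping rule can be illustrated with a minimal plain-Python sketch (no PyTorch involved; the class `Demo` and its methods are hypothetical names chosen for illustration). A bare name inside a method is looked up in the local and module scopes, never in the class body, so `helper` is only reachable as `Demo.helper`:

```python
class Demo:
    @staticmethod
    def helper(v):
        return v + 1

    @staticmethod
    def broken(v):
        # Class-body names are not in scope inside methods, so this bare
        # lookup fails at call time with a NameError.
        return helper(v)

    @staticmethod
    def working(v):
        # Looking the name up on the class works without any instance.
        return Demo.helper(v)


try:
    Demo.broken(1)
except NameError as err:
    print(err)  # name 'helper' is not defined

print(Demo.working(1))  # 2
```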

To use this function, you can either define it outside of the class, or define it as another static method and call it with PermutationLayer.want_to_add().
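Both options can be sketched as follows. This is not the original poster's layer: the forward pass (an elementwise sum of the three inputs) and the helper bodies are placeholders assumed only so the example runs, and `helper_outside` is a hypothetical name. What carries over is where each helper lives and how backward calls it.

```python
import torch
from torch import autograd


def helper_outside(g):
    # Option 1: a module-level function, visible by bare name inside backward.
    return g.clone()


class PermutationLayer(autograd.Function):
    @staticmethod
    def forward(ctx, x, x1, x2):
        # Placeholder forward (elementwise sum) standing in for the real computation.
        ctx.save_for_backward(x, x1, x2)
        return x + x1 + x2

    @staticmethod
    def want_to_add(g):
        # Option 2: a static method, called through the class name.
        return g.clone()

    @staticmethod
    def backward(ctx, grad_output):
        x, x1, x2 = ctx.saved_tensors  # not needed for this placeholder gradient
        grad_x = helper_outside(grad_output)
        grad_x1 = PermutationLayer.want_to_add(grad_output)
        grad_x2 = grad_output.clone()
        # One gradient per forward input (x, x1, x2); ctx gets none.
        return grad_x, grad_x1, grad_x2


x, x1, x2 = (torch.randn(3, requires_grad=True) for _ in range(3))
PermutationLayer.apply(x, x1, x2).sum().backward()
print(x1.grad)  # tensor([1., 1., 1.])
```

Note that the layer is invoked through `PermutationLayer.apply(...)`, never instantiated, which is why the static-method route goes through the class name rather than `self`.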
