Parametrize an autograd Function subclass implementation in PyTorch >= 0.2

I’m looking for some help on how to properly parametrize and/or inject logic into a Function subclass in PyTorch >= 0.2.

From my understanding, before version 0.2, I could do the following (adapted from the docs):

class Linear(Function):
    def __init__(self, my_custom_logic):
        self.my_custom_logic = my_custom_logic

    def forward(self, input, weight, bias=None):
        # ...
        # Compute output taking advantage of self.my_custom_logic
        # ...
        return output

    def backward(self, grad_output):
        grad_input = grad_weight = grad_bias = None
        # ...
        # Compute grads taking advantage of self.my_custom_logic
        # ...
        return grad_input, grad_weight, grad_bias

and then call it this way:

Linear(SomeObject.generate_my_custom_logic(my_args))(input, weight)

Clearly, my_custom_logic can be any Python object.
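For example (just to illustrate, not my actual use case), it could be a callable that forward applies to the output:

my_custom_logic = lambda t: t.clamp(min=0)
Linear(my_custom_logic)(input, weight)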

However, with PyTorch >= 0.2, forward and backward are static methods and ctx is shared only between them, so there is no obvious way to parametrize the implementation.

Is there any proper solution to this besides resorting to the deprecated API?

Hi,

You can do the following:

class Linear(Function):
    @staticmethod
    def forward(ctx, my_custom_logic, input, weight, bias=None):
        ctx.my_custom_logic = my_custom_logic
        # ...
        # Compute output taking advantage of ctx.my_custom_logic
        # ...
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_weight = grad_bias = None
        # ...
        # Compute grads taking advantage of ctx.my_custom_logic
        # ...
        # backward must return one value per forward() argument,
        # so return None in the slot of the non-tensor my_custom_logic
        return None, grad_input, grad_weight, grad_bias
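
You then call it through Linear.apply(my_custom_logic, input, weight), passing the extra argument first; backward simply returns None in that slot since nothing flows back into it. Here is a minimal, self-contained sketch of the pattern; the Scale class and its factor argument are made up for illustration, and it is written against the 0.2-era API, so inputs are wrapped in Variable:

import torch
from torch.autograd import Variable, Function

class Scale(Function):
    @staticmethod
    def forward(ctx, factor, input):
        # non-tensor arguments can be stashed on ctx as plain attributes
        ctx.factor = factor
        return input * factor

    @staticmethod
    def backward(ctx, grad_output):
        # one return value per forward() argument: None for the non-tensor factor
        return None, grad_output * ctx.factor

x = Variable(torch.randn(3), requires_grad=True)
y = Scale.apply(2.0, x)   # note: .apply, not the constructor
y.sum().backward()
print(x.grad)             # every entry equals the factor, 2.0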