I’m looking for help on how to properly parametrize and/or inject logic into a Function subclass in PyTorch >= 0.2.
From my understanding, before version 0.2, I could do the following (adapted from the docs):
from torch.autograd import Function

class Linear(Function):
    def __init__(self, my_custom_logic):
        self.my_custom_logic = my_custom_logic

    def forward(self, input, weight, bias=None):
        # ...
        # Compute output taking advantage of self.my_custom_logic
        # ...
        return output

    def backward(self, grad_output):
        grad_input = grad_weight = grad_bias = None
        # ...
        # Compute grads taking advantage of self.my_custom_logic
        # ...
        return grad_input, grad_weight, grad_bias
and then call it this way:
Linear(SomeObject.generate_my_custom_logic(my_args))(input, weight)
Here, my_custom_logic can be any Python object.
However, with PyTorch >= 0.2, forward and backward are static methods and ctx is the only state shared between them, so there is no obvious way to parametrize the implementation.
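For comparison, here is roughly what the same Linear looks like in the static-method style, modeled on the LinearFunction example from the "Extending torch.autograd" notes. This is a sketch against the current torch.autograd.Function API, not 0.2 specifically. Nothing is instantiated, so there is no self to carry my_custom_logic; the only per-call state is ctx.

```python
import torch
from torch.autograd import Function


class LinearFunction(Function):
    # No __init__ and no self: both methods are static and only see `ctx`.
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Tensors needed in backward are stashed on ctx.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output = output + bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Only compute the gradients autograd actually asks for.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias
```

It is invoked as `LinearFunction.apply(input, weight, bias)` rather than by constructing an instance.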
Is there a proper solution to this, other than resorting to the deprecated API?
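One candidate, sketched under the assumption that Function.apply accepts arbitrary non-tensor arguments (current releases appear to allow this): pass the logic object as an extra argument to apply, keep it directly on ctx, and return None as its gradient. TanhLogic, ParamLinear, and the fn/dfn interface are hypothetical names chosen only for illustration; my_custom_logic is assumed here to be a pointwise function applied after the affine map, together with its derivative.

```python
import torch
from torch.autograd import Function


class TanhLogic:
    """Hypothetical stand-in for my_custom_logic: a pointwise fn and its derivative."""

    def fn(self, x):
        return torch.tanh(x)

    def dfn(self, x):
        return 1 - torch.tanh(x) ** 2


class ParamLinear(Function):
    @staticmethod
    def forward(ctx, logic, input, weight, bias=None):
        # Non-tensor objects (the injected logic) can live directly on ctx;
        # tensors go through save_for_backward.
        ctx.logic = logic
        pre = input.mm(weight.t())
        if bias is not None:
            pre = pre + bias
        ctx.save_for_backward(input, weight, bias, pre)
        return logic.fn(pre)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias, pre = ctx.saved_tensors
        # Chain rule through the injected pointwise function.
        grad_pre = grad_output * ctx.logic.dfn(pre)
        grad_input = grad_pre.mm(weight)
        grad_weight = grad_pre.t().mm(input)
        grad_bias = grad_pre.sum(0) if bias is not None else None
        # One entry per forward argument; `logic` is not a tensor, so None.
        return None, grad_input, grad_weight, grad_bias
```

The original call would then become `ParamLinear.apply(SomeObject.generate_my_custom_logic(my_args), input, weight)`, provided the generated object exposes fn and dfn. The logic argument goes first so that bias can stay optional.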