Can I have a custom gradient for an input that is not a tensor?
In other words, I want to get rid of the following error, which occurs when I pass a non-tensor object (here: a linear operator) to a self-written `torch.autograd.Function`:

```
function ... returned a gradient different than None at position 1, but the corresponding forward input was not a Variable
```
Let me explain my use-case: I have a variety of linear operators that look more or less like this:
```python
class LinearOperator(torch.nn.Module):
    def __init__(self, tensor):
        super().__init__()
        self.tensor = tensor

    def forward(self, x):
        """For a linear operator A and a tensor x this implements x -> A(x)."""
        return something(self.tensor, x)

    def rmatvec(self, x):
        """This implements the transpose of A applied to a tensor: x -> A^T(x)."""
        return somethingTranspose(self.tensor, x)
```
Both `something` and `somethingTranspose` are implemented in PyTorch, so it should be possible to backprop through them.
While I can store the variable `tensor` from which the linear mapping is computed, I cannot, due to storage space restrictions, store the entire transformation matrix that the linear operator represents.
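To make the storage point concrete, here is a minimal matrix-free operator in the same shape as the sketch above (`DiagonalOperator` is a made-up example, not one of my actual operators): it stores only the n diagonal entries instead of the full n x n matrix.

```python
import torch

class DiagonalOperator(torch.nn.Module):
    """Hypothetical example: A = diag(d) needs only the n
    diagonal entries, never the full n x n matrix."""
    def __init__(self, d):
        super().__init__()
        self.tensor = d  # the diagonal, shape (n,)

    def forward(self, x):
        # x -> A(x), computed matrix-free
        return self.tensor * x

    def rmatvec(self, x):
        # x -> A^T(x); a diagonal matrix is its own transpose
        return self.tensor * x
```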
I also have a forward function that takes a linear operator A and a tensor x and produces some result f(A, x).
While I can backpropagate through the function f, I want to implement the gradient differently (i.e. more efficiently) by hand in the backward function.
The code then looks like this:
```python
class MyFunctionF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, x):
        ctx.A = A  # linear operator
        ctx.x = x  # PyTorch tensor
        result = computeF(A, x)  # result is a PyTorch tensor
        return result

    @staticmethod
    def backward(ctx, grad_output):
        grad_A = something_A(ctx.A, ctx.x, grad_output)
        grad_x = something_x(ctx.A, ctx.x, grad_output)
        return grad_A, grad_x
```
I am using these classes in the following way:
```python
# Both tensor and x are PyTorch tensors
tensor = ...
x = ...

linOp = LinearOperator(tensor)
res = MyFunctionF.apply(linOp, x)

linOp1 = OtherLinearOperator(tensor)
res1 = MyFunctionF.apply(linOp1, x)

linOp2 = NextLinearOperator(tensor)
res2 = MyFunctionF.apply(linOp2, x)
```
Here I have multiple different classes of linear operators (`LinearOperator`, `OtherLinearOperator`, `NextLinearOperator`) that all share the same custom function `f`.
The forward pass works fine, but during the backward pass I am getting the error

```
function MyFunctionFBackward returned a gradient different than None at position 1, but the corresponding forward input was not a Variable
```
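For reference, nothing exotic is needed to trigger this; a plain backward call suffices (sketch, assuming `x` was created with `requires_grad=True`):

```python
x = torch.randn(5, requires_grad=True)
res = MyFunctionF.apply(linOp, x)
res.sum().backward()  # raises the RuntimeError quoted above
```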
I guess the problem is that `grad_A` in the `backward` function is a tensor, but the corresponding input `A` to `forward` is not a tensor. Is this correct?
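If that guess is right, then changing the `backward` of `MyFunctionF` above to return `None` at the operator's position should presumably silence the error, at the cost of losing the gradient with respect to `tensor` (a sketch of what I mean):

```python
@staticmethod
def backward(ctx, grad_output):
    # None for the non-tensor input A avoids the error,
    # but then nothing flows back into A's underlying tensor
    grad_x = something_x(ctx.A, ctx.x, grad_output)
    return None, grad_x
```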
A workaround is to construct the linear operator inside the `forward` function of `MyFunctionF`, but this means that
- I have to implement the function f for every type of linear operator.
- I also have to differentiate through the `forward` method of the linear operator by hand, even though its gradient could be computed automatically with backprop.
```python
class MyFunctionF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, x):
        A = LinearOperator(tensor)
        ctx.A = A
        ctx.x = x
        result = computeF(A, x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        grad_tensor = something_tensor(ctx.A, ctx.x, grad_output)
        grad_x = something_x(ctx.A, ctx.x, grad_output)
        return grad_tensor, grad_x
```
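With this workaround, the call sites from above would presumably end up looking like this, with one `Function` subclass per operator type (the class names here are hypothetical):

```python
# one autograd.Function per operator class, each duplicating computeF
res  = MyLinearFunctionF.apply(tensor, x)   # wraps LinearOperator internally
res1 = MyOtherFunctionF.apply(tensor, x)    # wraps OtherLinearOperator
res2 = MyNextFunctionF.apply(tensor, x)     # wraps NextLinearOperator
```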
Therefore my question:
Is there a different way to implement linear operators in conjunction with custom backward functions?