Can I have a custom gradient for an input that is not a tensor?

In other words, I want to get rid of the following error when I pass a function to a self-written `torch.autograd.Function`

```
function ... returned a gradient different than None at position 1, but the corresponding forward input was not a Variable
```
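To make the setup concrete, here is a minimal sketch that reproduces this class of error. `Scale` and `factor` are hypothetical names, not from my real code; the point is only that the first forward input is a plain Python float rather than a tensor, while `backward` returns a tensor gradient at that position:

```
import torch

class Scale(torch.autograd.Function):
    """Hypothetical minimal repro: the first forward input is a plain float."""

    @staticmethod
    def forward(ctx, factor, x):
        ctx.factor = factor
        ctx.save_for_backward(x)
        return factor * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Returning a tensor gradient for the float input (instead of None)
        # triggers the "gradient different than None" RuntimeError.
        return (grad_output * x).sum(), grad_output * ctx.factor

x = torch.ones(3, requires_grad=True)
try:
    Scale.apply(2.0, x).sum().backward()
except RuntimeError as err:
    print("RuntimeError:", err)
```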

## My Use-case

Let me explain my use-case: I have a variety of linear operators that look more or less like this:

```
class LinearOperator(torch.nn.Module):
    def __init__(self, tensor):
        super().__init__()
        self.tensor = tensor

    def forward(self, x):
        """
        For a linear operator A and a tensor x this implements x -> A(x).
        """
        return something(self.tensor, x)

    def rmatvec(self, x):
        """
        This implements the transpose of A applied to a tensor: x -> A^T(x).
        """
        return somethingTranspose(self.tensor, x)
```
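As a concrete stand-in for `something` and `somethingTranspose` (which are placeholders above), imagine a hypothetical diagonal operator, i.e. A = diag(tensor). It never materializes the full matrix, which is exactly the storage constraint I am working under:

```
import torch

class DiagonalOperator(torch.nn.Module):
    """Hypothetical concrete instance of the pattern: A = diag(tensor)."""

    def __init__(self, tensor):
        super().__init__()
        self.tensor = tensor

    def forward(self, x):
        # A(x) as an elementwise product; the dense matrix is never built
        return self.tensor * x

    def rmatvec(self, x):
        # A diagonal matrix is symmetric, so A^T(x) = A(x)
        return self.tensor * x
```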

The functions `something` and `somethingTranspose` are implemented in PyTorch, so it should be possible to backprop through them.

While I can store the variable `tensor` with which I can compute the linear mapping, I cannot, due to storage space restrictions, store the entire transformation matrix that represents the linear operator `A`.

I also have a forward function that takes a linear operator A and a tensor x and produces some result f(A, x).

While I can backpropagate through the function f, I want to implement the gradient differently (i.e. more efficiently) by hand in the backward function.

The code then looks like this:

```
class MyFunctionF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, x):
        ctx.A = A  # linear operator
        ctx.x = x  # PyTorch tensor
        result = computeF(A, x)
        # result is a PyTorch tensor
        return result

    @staticmethod
    def backward(ctx, grad_output):
        grad_A = something_A(ctx.A, ctx.x, grad_output)
        grad_x = something_x(ctx.A, ctx.x, grad_output)
        return grad_A, grad_x
```

I am using these classes in the following way:

```
# Both tensor and x are PyTorch Tensors
tensor = ...
x = ...
linOp = LinearOperator(tensor)
res = MyFunctionF.apply(linOp, x)
linOp1 = OtherLinearOperator(tensor)
res1 = MyFunctionF.apply(linOp1, x)
linOp2 = NextLinearOperator(tensor)
res2 = MyFunctionF.apply(linOp2, x)
```

Here I have multiple different classes of linear operators (`LinearOperator`, `OtherLinearOperator`, `NextLinearOperator`, …).

The forward pass works fine, but during the backward pass I am getting the error

`function MyFunctionFBackward returned a gradient different than None at position 1, but the corresponding forward input was not a Variable`

**I guess the problem is that grad_A in the backward function is a tensor, but the input A in forward is not a tensor. Is this correct?**
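This reading is consistent with a minimal sketch in which `backward` returns `None` at the non-tensor position: the backward pass then runs without the error (`NoGradForObject` is a hypothetical name, and of course this gives no gradient for the operator itself, which is precisely what I want to avoid):

```
import torch

class NoGradForObject(torch.autograd.Function):
    """Hypothetical sketch: None at the non-tensor position avoids the error."""

    @staticmethod
    def forward(ctx, obj, x):
        ctx.obj = obj  # arbitrary non-tensor object
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # None for the non-tensor input, a real gradient only for x
        return None, grad_output * 2.0

x = torch.ones(3, requires_grad=True)
NoGradForObject.apply(object(), x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```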

## Workaround

A workaround is to construct the linear operator inside the forward function of `MyFunctionF`, but this means that

- I have to implement the function f for every type of linear operator.
- I also have to backpropagate through the forward method of the linear operator by hand, even though its gradient could be computed with regular backprop.

```
class MyFunctionF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, x):
        A = LinearOperator(tensor)
        ctx.A = A
        ctx.x = x
        result = computeF(A, x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        grad_tensor = something_tensor(ctx.A, ctx.x, grad_output)
        grad_x = something_x(ctx.A, ctx.x, grad_output)
        return grad_tensor, grad_x
```

**Therefore my question:**

**Is there a different way to implement linear operators in conjunction with custom backward functions?**