I have a custom module that rearranges the values of its input in a sophisticated way (so I have to extend autograd).
The double backward of the gradients should therefore be the same operation as the backward of the gradients, similar to reshape?
If I define it this way in XXXFunction.py:
@staticmethod
def backward(ctx, grad_output):
    # do something to rearrange grad_output to get grad_input
    return grad_input
then the double backward does not seem to be correct. How can I customize the double backward?
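For concreteness, here is a minimal sketch of the kind of Function I mean (my own illustration, assuming the "rearrangement" is just a fixed index permutation `perm`; the real module is more involved):

import torch

class RearrangeFunction(torch.autograd.Function):
    # Hypothetical example: permute the elements of a 1-D input by a fixed index tensor perm.
    @staticmethod
    def forward(ctx, X, perm):
        ctx.save_for_backward(perm)
        return X[perm]

    @staticmethod
    def backward(ctx, grad_output):
        perm, = ctx.saved_tensors
        # Scatter the gradient back to the original positions (the inverse rearrangement).
        grad_input = torch.empty_like(grad_output)
        grad_input[perm] = grad_output
        return grad_input, None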
import torch

class CubeFunction(torch.autograd.Function):
    """
    Dummy activation function x -> x ** 3
    """
    @staticmethod
    def forward(ctx, X):
        ctx.save_for_backward(X)
        return X ** 3

    @staticmethod
    def backward(ctx, M):
        X, = ctx.saved_tensors
        # Delegate to a second Function so that the backward pass itself
        # builds a differentiable graph (this is what enables double backward).
        return CubeFunctionBackward.apply(X, M)


class CubeFunctionBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, M):
        ctx.save_for_backward(X, M)
        return M * 3 * X ** 2

    @staticmethod
    def backward(ctx, V):
        X, M = ctx.saved_tensors
        # Gradients of (M * 3 * X ** 2) with respect to X and M, each contracted with V.
        return V * 6 * X * M, V * 3 * X ** 2
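A quick way to verify both levels (my own sanity check, not part of the original post) is to run gradcheck and gradgradcheck on a double-precision input, assuming CubeFunction is defined as above:

import torch
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(5, dtype=torch.double, requires_grad=True)
# First-order gradients of CubeFunction against numerical finite differences.
print(gradcheck(CubeFunction.apply, (x,)))
# Second-order gradients, i.e. the CubeFunctionBackward path.
print(gradgradcheck(CubeFunction.apply, (x,)))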
Hey @amensch, I was wondering if I could ask a question about the grad_output (V) in the double backward?
From my understanding, grad_output is the gradient of the loss with respect to the output of the function. So M is d(loss)/df where f = x^3, and if the loss is just the output of the function, then in effect d(loss)/df = 1.
I understand that there are two returned gradients, where backward1 = M * 3 * x^2 is differentiated with respect to X and to M, and each result is then multiplied by V in order to get the correct second-order gradient for each variable, where V would be the derivative of some loss with respect to backward1. But what is this second-order loss?
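To make my question concrete, this is the kind of double-backward call I have in mind (my own sketch; `w` is just an arbitrary vector I pass as grad_outputs when differentiating the first-order gradient a second time):

import torch

x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = CubeFunction.apply(x)

# First backward: build the gradient graph; here M is the all-ones tensor passed as grad_outputs.
g, = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)

# Second backward: differentiate g (i.e. backward1 = M * 3 * x**2) with respect to x.
# The vector w passed as grad_outputs here is what arrives in CubeFunctionBackward.backward as V.
w = torch.randn_like(g)
gg, = torch.autograd.grad(g, x, grad_outputs=w)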