How to customize the double backward?

I have a custom module that rearranges the values of its input in a sophisticated way, so I have to extend autograd.
So the double backward of the gradients should be the same as the backward of the gradients, similar to reshape?

If I define backward like this:

    def backward(ctx, grad_output):
        # do something to rearrange grad_output to get grad_input
        return grad_input

The double backward does not seem to be correct. How can I customize the double backward?

Define a second Function that computes the backward of this Function (and call it from this backward), then define the double backward in that second Function’s backward.

Hi SimonW, I am still a bit confused. Could you please give some pseudocode for this? Thank you very much.

Hi!

You may go for

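    import torch

    class CubeFunction(torch.autograd.Function):
        """Dummy activation function x -> x ** 3"""

        @staticmethod
        def forward(ctx, X):
            ctx.save_for_backward(X)  # backward needs X
            return X ** 3

        @staticmethod
        def backward(ctx, M):
            X, = ctx.saved_tensors
            # Delegate to a second Function so this step is itself differentiable
            return CubeFunctionBackward.apply(X, M)

    class CubeFunctionBackward(torch.autograd.Function):
        """Backward of CubeFunction, written as a Function of (X, M)"""

        @staticmethod
        def forward(ctx, X, M):
            ctx.save_for_backward(X, M)
            return M * 3 * X ** 2

        @staticmethod
        def backward(ctx, V):
            X, M = ctx.saved_tensors
            # Gradients of M * 3 * X ** 2 w.r.t. X and M, each scaled by V
            return V * 6 * X * M, V * 3 * X ** 2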

See this gist for a complete demo.

I could PR this to the documentation if any developer deems it interesting enough to be exposed.
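
For the rearrangement case from the original question, the same two-Function pattern might look roughly like the sketch below. It assumes the rearrangement is a fixed permutation of a 1-D tensor. `Permute`, `PermuteBackward` and `perm` are made-up names for illustration and are not part of PyTorch. Because the operation is linear, the double backward ends up being the forward rearrangement again, which is what the comparison with reshape suggests.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck


class Permute(torch.autograd.Function):
    """Hypothetical rearrangement: out[i] = x[perm[i]] for a 1-D tensor x."""

    @staticmethod
    def forward(ctx, x, perm):
        ctx.save_for_backward(perm)
        return x[perm]

    @staticmethod
    def backward(ctx, grad_output):
        perm, = ctx.saved_tensors
        # Route the gradient through a second Function so it stays differentiable.
        # perm is an index tensor, so it gets no gradient (None).
        return PermuteBackward.apply(grad_output, perm), None


class PermuteBackward(torch.autograd.Function):
    """Backward of Permute: send each gradient entry back to its source position."""

    @staticmethod
    def forward(ctx, grad_output, perm):
        ctx.save_for_backward(perm)
        # argsort of a permutation is its inverse permutation
        return grad_output[torch.argsort(perm)]

    @staticmethod
    def backward(ctx, v):
        perm, = ctx.saved_tensors
        # The map is linear, so its backward is the forward rearrangement again.
        return Permute.apply(v, perm), None


# Numerical checks of the first- and second-order gradients (double precision).
x = torch.randn(5, dtype=torch.double, requires_grad=True)
perm = torch.tensor([2, 0, 4, 1, 3])  # example permutation, chosen arbitrarily


def fn(t):
    return Permute.apply(t, perm)


first_order_ok = gradcheck(fn, (x,))
second_order_ok = gradgradcheck(fn, (x,))
```

`gradcheck` and `gradgradcheck` compare the analytic gradients against finite differences, so they are a reasonable way to confirm that a hand-written double backward is consistent with the forward.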


Aha, interesting. I would like to see a PR for this, though I don’t know what the developers think about it.

Hey @amensch, I was wondering if I could ask a question about the grad_output (V) in the double backward?

From my understanding, grad_output is the gradient of the loss with respect to the output of the function. So, M is d(loss)/df where f = x^3 and the loss is the output of the function. So in effect, d(loss)/df = 1.

I understand that there are 2 returned gradients: backward1 = M * 3 * x^2 is differentiated with respect to X and M, and each result is then multiplied by V in order to get the correct 2nd order gradient for each variable, where V would be the derivative of some loss with respect to backward1. But what is this 2nd order loss?

Hopefully, this makes sense!

Thank you!
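
One way to read V, sketched here with plain autograd on x ** 3 rather than the custom Function: the "2nd order loss" can be any scalar that depends on the first gradient, for example a gradient penalty. The `loss2` below is an invented example of such a loss, not something taken from the thread. V is then d(loss2)/d(backward1), and the value autograd computes for d(loss2)/dx should match the first term returned by `CubeFunctionBackward.backward`.

```python
import torch

x = torch.tensor([1.0, 2.0, -0.5], dtype=torch.double, requires_grad=True)

y = x ** 3
M = torch.ones_like(y)  # d(y.sum())/dy, the grad_output of the first backward

# First backward, kept differentiable: g = M * 3 * x ** 2
g, = torch.autograd.grad(y, x, grad_outputs=M, create_graph=True)

# Hypothetical second-order loss: a scalar that depends on the gradient itself
loss2 = (g ** 2).sum()

# V = d(loss2)/dg is what the double backward receives as its grad_output
V, = torch.autograd.grad(loss2, g, retain_graph=True)

# Gradient of loss2 w.r.t. x as computed by autograd ...
grad_x, = torch.autograd.grad(loss2, x)

# ... compared with the hand-written double backward term V * 6 * X * M
manual = V * 6 * x.detach() * M
matches = torch.allclose(grad_x, manual)
```

If the first gradient is only ever used as a final result and never feeds into another scalar, there is no second-order loss and the double backward is simply never called.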