How to customize the double backward?

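Hi!

You can implement the backward pass itself as a second `torch.autograd.Function`. The `backward` of that second Function then defines the double backward: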

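```python
import torch


class CubeFunction(torch.autograd.Function):
    """
    Dummy activation function x -> x ** 3
    """
    @staticmethod
    def forward(ctx, X):
        ctx.save_for_backward(X)
        return X ** 3

    @staticmethod
    def backward(ctx, M):
        # M is the gradient w.r.t. the output. Computing the gradient through
        # another Function (instead of inline tensor ops) means that, when
        # create_graph=True, the recorded node is CubeFunctionBackward, whose
        # own backward defines the double backward of CubeFunction.
        X, = ctx.saved_tensors
        return CubeFunctionBackward.apply(X, M)


class CubeFunctionBackward(torch.autograd.Function):
    """
    (X, M) -> M * 3 * X ** 2, i.e. the backward of CubeFunction
    """
    @staticmethod
    def forward(ctx, X, M):
        ctx.save_for_backward(X, M)
        return M * 3 * X ** 2

    @staticmethod
    def backward(ctx, V):
        # One gradient per forward input: d/dX and d/dM of M * 3 * X ** 2,
        # each multiplied by the incoming gradient V
        X, M = ctx.saved_tensors
        return V * 6 * X * M, V * 3 * X ** 2
```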
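The two `backward` methods follow from elementwise differentiation. With $f(X) = X^3$ and incoming gradient $M$, `CubeFunction.backward` returns $g(X, M)$ below; since $g$ depends on both $X$ and $M$, `CubeFunctionBackward.backward` returns one term per input, each multiplied by its incoming gradient $V$ (all products elementwise):

```latex
\begin{aligned}
g(X, M) &= \frac{\partial f}{\partial X}\, M = 3X^2 M \\
\frac{\partial g}{\partial X}\, V &= 6 X M V,
\qquad
\frac{\partial g}{\partial M}\, V = 3 X^2 V
\end{aligned}
```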

See this gist https://gist.github.com/arthurmensch/b02c45f3440c9d3de0ef2c0ae5ea1107 for a complete demo.
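The gist above is the full demo. As a separate minimal sketch (not the gist's code), this is one way the two Functions might be exercised: pass `create_graph=True` to `torch.autograd.grad` so the backward pass is itself recorded, then differentiate again. The helper `first_and_second_derivative` is a hypothetical name for illustration. `torch.autograd.gradcheck` and `gradgradcheck` compare the analytical first and second derivatives against finite differences and expect double-precision inputs. The classes are repeated so the block runs on its own.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck


class CubeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X):
        ctx.save_for_backward(X)
        return X ** 3

    @staticmethod
    def backward(ctx, M):
        X, = ctx.saved_tensors
        # Differentiable backward: recorded in the graph when create_graph=True
        return CubeFunctionBackward.apply(X, M)


class CubeFunctionBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, M):
        ctx.save_for_backward(X, M)
        return M * 3 * X ** 2

    @staticmethod
    def backward(ctx, V):
        X, M = ctx.saved_tensors
        return V * 6 * X * M, V * 3 * X ** 2


def first_and_second_derivative(x):
    """Hypothetical helper: elementwise d/dx and d2/dx2 of sum(x ** 3)."""
    x = x.detach().requires_grad_(True)
    y = CubeFunction.apply(x).sum()
    # create_graph=True keeps the gradient differentiable
    (dy,) = torch.autograd.grad(y, x, create_graph=True)
    # Second pass goes through CubeFunctionBackward.backward
    (d2y,) = torch.autograd.grad(dy.sum(), x)
    return dy.detach(), d2y


if __name__ == "__main__":
    dy, d2y = first_and_second_derivative(torch.tensor([1.0, 2.0, -3.0]))
    print(dy)   # equals 3 * x ** 2
    print(d2y)  # equals 6 * x

    # Finite-difference checks of the first and second derivatives
    inp = torch.randn(5, dtype=torch.double, requires_grad=True)
    print(gradcheck(CubeFunction.apply, (inp,)))
    print(gradgradcheck(CubeFunction.apply, (inp,)))
```

Routing the gradient through a Function keeps the second derivative under your control; if `backward` used plain tensor ops instead, autograd would derive the double backward from those ops automatically.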

I could open a PR adding this example to the documentation if the developers think it is worth exposing.
