Gradient calculation for numerically inverted function

You have two tensors x and z, and an analytical formula for x = x(z) as a bijective function of z, so the gradient dx/dz can be computed during training. Now suppose you want to compute z(x) for a given x. This inverse can only be computed numerically (e.g. by bisection search), and to do so you have to detach the result of the inversion. As a consequence, the gradient dz/dx cannot be obtained with automatic differentiation. It can, however, be computed as (dx/dz)^(-1), both analytically and with autograd.grad.
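As a concrete illustration of the inverse-function rule, here is a minimal sketch. It assumes a hypothetical strictly increasing map x(z) = z**3 + z and a hypothetical vectorized bisection helper `bisect_inverse` (neither is from the thread). It inverts x numerically, computes dz/dx as 1/(dx/dz) using torch.autograd.grad, and compares that with a finite-difference estimate taken through the bisection itself.

```python
import torch

def x_of_z(z):
    # hypothetical bijective (strictly increasing) map x = x(z)
    return z**3 + z

def bisect_inverse(f, x, low=-10.0, high=10.0, iters=100):
    # hypothetical elementwise bisection: returns z with f(z) close to x,
    # assuming f is increasing and every solution lies in [low, high]
    lo = torch.full_like(x, low)
    hi = torch.full_like(x, high)
    for _ in range(iters):
        mid = (lo + hi) / 2
        above = f(mid) > x               # mid overshoots, so the root is left of mid
        hi = torch.where(above, mid, hi)
        lo = torch.where(above, lo, mid)
    return (lo + hi) / 2

x = torch.tensor([2.0, 10.0], dtype=torch.float64)
with torch.no_grad():
    z = bisect_inverse(x_of_z, x)        # numerical inverse, no autograd graph

# dx/dz from autograd on a gradient-requiring copy of z
z_req = z.clone().requires_grad_()
dx_dz, = torch.autograd.grad(x_of_z(z_req), z_req, torch.ones_like(z_req))
dz_dx = 1.0 / dx_dz                      # inverse function rule: dz/dx = (dx/dz)^(-1)

# finite-difference estimate of dz/dx through the bisection itself
eps = 1e-6
with torch.no_grad():
    fd = (bisect_inverse(x_of_z, x + eps) - bisect_inverse(x_of_z, x - eps)) / (2 * eps)

print(z)      # approximately [1., 2.]
print(dz_dx)  # approximately [0.25, 1/13]
```

Here z**3 + z equals 2 at z = 1 and 10 at z = 2, so dx/dz is 4 and 13 there, and the reciprocal matches the finite-difference slope of the numerical inverse.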

I am facing the following questions:

  • Is there a way to reconnect the differentiation graph after calling .detach() for the bisection search?
  • Is there a way to supply the correct gradient (dx/dz)^(-1) for the detached connection?

Any advice is welcome 🙂 Cheers


Yes, you would want to write a custom autograd.Function. Note that you don't need to detach the tensor, as the forward is implicitly run without gradient tracking.
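A small sketch to check that claim: the hypothetical `Probe` Function below records `torch.is_grad_enabled()` while its forward runs. The forward sees grad mode switched off, yet the output still requires grad because the Function itself becomes the node in the graph, and its backward supplies the gradient.

```python
import torch
from torch.autograd import Function

class Probe(Function):
    # hypothetical example Function; records grad mode seen inside forward()
    seen_grad_enabled = None

    @staticmethod
    def forward(ctx, x):
        Probe.seen_grad_enabled = torch.is_grad_enabled()
        return x * 2  # operations here are not recorded by autograd

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2  # gradient supplied by hand

x = torch.ones(3, requires_grad=True)
y = Probe.apply(x)
print(Probe.seen_grad_enabled)  # False
print(y.requires_grad)          # True
```

So anything done inside forward, such as a bisection loop, is already outside the graph without an explicit .detach().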

Yes, you can use torch.autograd.grad to get dx/dz after recomputing x from a gradient-requiring copy of z inside a torch.enable_grad() block. I would probably do this in the backward.
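A minimal sketch of that recipe. A custom Function's backward normally runs with grad mode disabled (unless create_graph=True is requested), which is mimicked here with torch.no_grad(). Inside torch.enable_grad(), a gradient-requiring copy of z lets torch.autograd.grad return dx/dz. The map a*z + z**3 stands in for the user's function and is an assumption.

```python
import torch

def myfunction(z, a):
    # hypothetical forward map x = x(z; a), standing in for the user's function
    return a * z + z**3

a = torch.tensor(2.0)
z = torch.tensor([0.5, 1.0])

with torch.no_grad():                 # mimics the grad-disabled context of backward()
    with torch.enable_grad():         # re-enable tracking locally
        z_req = z.detach().clone().requires_grad_()
        x = myfunction(z_req, a)
        # elementwise dx/dz = a + 3 z^2
        dx_dz, = torch.autograd.grad(x, z_req, torch.ones_like(x))

print(dx_dz)  # tensor([2.7500, 5.0000])
```

Passing torch.ones_like(x) as grad_outputs gives the elementwise derivative only because x_i depends on z_i alone; for a non-elementwise map you would need the full Jacobian instead.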

Best regards


Thanks a lot for your answer!
In case anyone is interested, this is the solution I came up with at the time:

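import torch
from functools import partial
from torch.autograd import grad, Function

# myfunction(z, par_1) computes x = x(z; par_1) elementwise, and
# bisection_invert(f, x) numerically solves f(z) = x; both are user-defined.

class MyFunctionInverse(Function):

    @staticmethod
    def forward(ctx, x, par_1):

        # get inverse numerically (forward already runs without grad tracking)
        func_partial = partial(myfunction, par_1=par_1)
        z = bisection_invert(func_partial, x)

        # save for backward
        ctx.save_for_backward(z, par_1)

        return z

    @staticmethod
    def backward(ctx, grad_z):

        with torch.enable_grad():

            # gradient-requiring leaf copies to rebuild x = x(z; par_1)
            z, par_1 = map(
                lambda t: t.detach().clone().requires_grad_(), ctx.saved_tensors)

            x = myfunction(z, par_1)

            # elementwise dx/dz, then dz/dx = (dx/dz)^(-1) via the chain rule
            grad_x_inverse, = grad(x, z, torch.ones_like(x), retain_graph=True)
            grad_x = grad_z*grad_x_inverse.pow(-1)

            # minus due to parameter dependency: dz/dpar_1 = -(dx/dpar_1) / (dx/dz)
            # grad returns a tuple, so unpack the single entry
            grad_par_1, = grad(
                x, [par_1], - grad_x)

        return grad_x, grad_par_1

# usage: z = MyFunctionInverse.apply(x, par_1)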
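The two formulas in the backward, including the minus sign, follow from differentiating the identity f(z(x; θ); θ) = x, where f stands for myfunction and θ for par_1. This sketch assumes f acts elementwise on z, as the use of torch.ones_like(x) in the code implies; \bar{z} denotes grad_z and \bar{x} denotes grad_x.

```latex
% x_i = f(z_i;\theta) with inverse z_i = z(x_i;\theta), so f(z(x_i;\theta);\theta) = x_i
\frac{\partial f}{\partial z}\,\frac{\partial z_i}{\partial x_i} = 1
\;\Rightarrow\;
\frac{\partial z_i}{\partial x_i} = \Big(\frac{\partial f}{\partial z}\Big)^{-1},
\qquad
\frac{\partial f}{\partial z}\,\frac{\partial z_i}{\partial \theta} + \frac{\partial f}{\partial \theta} = 0
\;\Rightarrow\;
\frac{\partial z_i}{\partial \theta} = -\Big(\frac{\partial f}{\partial z}\Big)^{-1}\frac{\partial f}{\partial \theta}

% backward pass with incoming gradient \bar{z}
\bar{x}_i = \bar{z}_i\,\Big(\frac{\partial f}{\partial z}(z_i;\theta)\Big)^{-1},
\qquad
\bar{\theta} = \sum_i \bar{z}_i\,\frac{\partial z_i}{\partial \theta}
             = -\sum_i \bar{x}_i\,\frac{\partial f}{\partial \theta}(z_i;\theta)
```

The last sum is exactly the vector-Jacobian product of x with respect to par_1 using -grad_x as the vector, which is what grad(x, [par_1], - grad_x) computes.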

Did you put all the parameters of your model into [par_1], e.g. [a, b, c, d]?

Hi 🙂 par_1 is a parameter of the function, e.g. the slope a in the linear function f(x; a) = a*x. You can use several parameters.
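A minimal end-to-end sketch with several parameters, assuming a hypothetical map x = a*z + b*z**3 (strictly increasing for a, b > 0) and a hypothetical vectorized `bisection_invert`. Each parameter is passed to apply as its own tensor argument, since autograd.Function only tracks tensors passed directly to apply (tensors nested inside a list are not treated as inputs). backward returns one gradient per input, and a single vector-Jacobian product yields all parameter gradients at once. The result is checked with torch.autograd.gradcheck in double precision.

```python
import torch
from torch.autograd import Function, grad, gradcheck

def myfunction(z, a, b):
    # hypothetical forward map x = a*z + b*z**3, strictly increasing for a, b > 0
    return a * z + b * z**3

def bisection_invert(f, x, low=-100.0, high=100.0, iters=200):
    # hypothetical elementwise bisection; assumes f increasing and root in [low, high]
    lo = torch.full_like(x, low)
    hi = torch.full_like(x, high)
    for _ in range(iters):
        mid = (lo + hi) / 2
        above = f(mid) > x
        hi = torch.where(above, mid, hi)
        lo = torch.where(above, lo, mid)
    return (lo + hi) / 2

class MyFunctionInverse(Function):
    @staticmethod
    def forward(ctx, x, a, b):
        # numerical inverse; forward runs without grad tracking
        z = bisection_invert(lambda t: myfunction(t, a, b), x)
        ctx.save_for_backward(z, a, b)
        return z

    @staticmethod
    def backward(ctx, grad_z):
        with torch.enable_grad():
            z, a, b = (t.detach().clone().requires_grad_() for t in ctx.saved_tensors)
            x = myfunction(z, a, b)
            dx_dz, = grad(x, z, torch.ones_like(x), retain_graph=True)
            grad_x = grad_z / dx_dz                  # dz/dx = (dx/dz)^(-1)
            # dz/dtheta = -(dx/dtheta) / (dx/dz) for every parameter theta
            grad_a, grad_b = grad(x, (a, b), -grad_x)
        return grad_x, grad_a, grad_b

x = torch.tensor([0.3, 2.0, -1.5], dtype=torch.float64, requires_grad=True)
a = torch.tensor(1.5, dtype=torch.float64, requires_grad=True)
b = torch.tensor(0.7, dtype=torch.float64, requires_grad=True)

z = MyFunctionInverse.apply(x, a, b)
print(torch.allclose(myfunction(z, a, b), x))          # the inverse round-trips
print(gradcheck(MyFunctionInverse.apply, (x, a, b)))   # analytic vs. numerical grads
```

gradcheck compares the hand-written backward against finite differences; double precision and enough bisection iterations keep the numerical inverse accurate enough for that comparison.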