Suppose you have two tensors x and z, and an analytical formula for x = x(z) as a bijective function, so the gradient during training can be calculated. Now suppose you want to compute z(x) for a given x. This inverse can only be evaluated numerically (e.g. by bisection search), and doing so forces you to detach the result of the inversion. During this step it is therefore not possible to obtain the gradient dz/dx by automatic differentiation. It is, however, possible to compute it as (dx/dz)^(-1), both analytically (by the inverse function theorem) and with autograd.grad.
I am facing the following questions:
Is there a way to reconnect the differentiation graph after .detach() for the bisection search?
Is there a way to supply the correct gradient (dx/dz)^(-1) across the detached connection?
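For concreteness, here is a minimal sketch of the situation (the bijection x_of_z, the search bounds, and the iteration count are made up for illustration):

```python
import torch

def x_of_z(z):
    # stand-in for the known analytical bijection x = x(z)
    # (strictly increasing, so bisection works elementwise)
    return z + z**3

def z_of_x(x, lo=-10.0, hi=10.0, iters=60):
    # numerical inverse by bisection; x is detached, so the
    # result carries no graph back to the input
    lo = torch.full_like(x, lo)
    hi = torch.full_like(x, hi)
    x = x.detach()
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        below = x_of_z(mid) < x
        lo = torch.where(below, mid, lo)
        hi = torch.where(below, hi, mid)
    return 0.5 * (lo + hi)

x = torch.tensor([2.0], requires_grad=True)
z = z_of_x(x)
print(z.requires_grad)  # False: the graph to x is lost
```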
Yes, you would want to write a custom autograd.Function. Note that you don’t need to detach the tensor, since the forward of a custom Function is implicitly run in no-grad mode.
Yes, you can use torch.autograd.grad to get dx/dz after recomputing x from a gradient-requiring copy of z inside a torch.enable_grad() block. I would probably do this in the backward.
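A sketch of how the two pieces fit together, using a made-up elementwise bijection f(z) = z + z**3 in place of your analytical formula (bounds and iteration count are likewise illustrative):

```python
import torch

class NumericalInverse(torch.autograd.Function):
    """z = f^{-1}(x) via bisection, with dz/dx = (dx/dz)^{-1} in backward."""

    @staticmethod
    def f(z):
        # placeholder for the known analytical x(z); strictly increasing
        return z + z**3

    @staticmethod
    def forward(ctx, x):
        # runs in no-grad mode implicitly, so no explicit .detach() needed
        lo = torch.full_like(x, -10.0)
        hi = torch.full_like(x, 10.0)
        for _ in range(60):  # bisection search for z with f(z) = x
            mid = 0.5 * (lo + hi)
            below = NumericalInverse.f(mid) < x
            lo = torch.where(below, mid, lo)
            hi = torch.where(below, hi, mid)
        z = 0.5 * (lo + hi)
        ctx.save_for_backward(z)
        return z

    @staticmethod
    def backward(ctx, grad_out):
        (z,) = ctx.saved_tensors
        with torch.enable_grad():
            # gradient-requiring copy of z, recompute x, then dx/dz
            z_ = z.detach().requires_grad_()
            x_ = NumericalInverse.f(z_)
            (dxdz,) = torch.autograd.grad(x_, z_, torch.ones_like(x_))
        # inverse function theorem, elementwise: dz/dx = 1 / (dx/dz)
        return grad_out / dxdz

x = torch.tensor([2.0], requires_grad=True)
z = NumericalInverse.apply(x)   # f(1) = 2, so z ~ 1
z.sum().backward()              # x.grad ~ 1 / (1 + 3*z**2) = 0.25
```

You can sanity-check the backward against finite differences with torch.autograd.gradcheck (use double precision and a tight enough bisection tolerance for that).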