Return gradient of other tensor in autograd.Function.backward (to approximate gradient with bilinear interpolation)

I have a function nondifferentiable_but_exact which returns the exact value for a given argument but is not differentiable. However, I know an approximation of this function which is not exact, but which is differentiable and can easily be expressed as a PyTorch expression. Now, instead of manually deriving the gradient - which I will surely get wrong - I would like to simply take the gradient of the differentiable expression and return it for the non-differentiable one. I have played around with the following test code for doing this:

import torch


class Test(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # TODO This does not work. The idea is to get the gradient of b and return it for a.
        grad_a = b.grad
        assert grad_a is not None
        return grad_a, None

a = torch.ones(1, requires_grad=True)
a = nondifferentiable_but_exact(a)
b = differentiable_but_approximate_function(a)

Test.apply(a, b).backward()

Unfortunately, I am unable to get the gradient of b inside the backward function. I have also tried calling backward(retain_graph=True) in there, but that does not seem to work either.

Is there a solution to this problem?

Some more background: I want to approximate the gradient of nondifferentiable_but_exact with a bilinear interpolation. So I will simply call nondifferentiable_but_exact 5 times: once for the exact value and 4 more times to get the surrounding points for the interpolation. I will discard the interpolated value and only use its gradient. Also, for my use case it is important that I use the exact value of the function. Therefore, I cannot simply use the interpolated value (and be done).
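For reference, the gradient of a bilinear interpolant follows directly from the four sampled corner values, so the gradient part can be computed without autograd at all. A minimal pure-Python sketch (the unit grid spacing, the helper name bilinear_grad, and the staircase example function are my illustrative assumptions, not from the original post):

```python
import math

def bilinear_grad(f, x, y):
    """Gradient of the bilinear interpolant of f at (x, y).

    f is sampled at the four integer grid corners surrounding (x, y);
    the interpolated value itself is discarded, only the gradient is kept.
    """
    x0, y0 = math.floor(x), math.floor(y)
    fx, fy = x - x0, y - y0  # fractional position inside the grid cell
    f00 = f(x0, y0)
    f10 = f(x0 + 1, y0)
    f01 = f(x0, y0 + 1)
    f11 = f(x0 + 1, y0 + 1)
    # Partial derivatives of the bilinear interpolant
    dfdx = (f10 - f00) * (1 - fy) + (f11 - f01) * fy
    dfdy = (f01 - f00) * (1 - fx) + (f11 - f10) * fx
    return dfdx, dfdy

# Example: a nondifferentiable "staircase" function
staircase = lambda x, y: math.floor(x) + math.floor(y)
print(bilinear_grad(staircase, 1.5, 2.25))  # (1.0, 1.0)
```

Together with one extra call for the exact value at (x, y), this accounts for the 5 evaluations mentioned above.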

The way I would do this is the following (assuming there are NO learnable parameters in nondifferentiable_but_exact or differentiable_but_approximate_function, and that you don't need higher-order gradients via create_graph=True):

import torch


class Test(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a)
        exact_out = nondifferentiable_but_exact(a)
        return exact_out

    @staticmethod
    def backward(ctx, grad_output):
        a, = ctx.saved_tensors
        # Get a new leaf Tensor with the same content, detached from a
        local_a = a.detach().requires_grad_()
        with torch.enable_grad():
            approx_out = differentiable_but_approximate_function(local_a)
        # Backpropagate grad_output through the approximation only
        grad = torch.autograd.grad(approx_out, local_a, grad_output)[0]
        return grad

a = torch.ones(1, requires_grad=True)

Test.apply(a).backward()
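To make this concrete, here is a self-contained sketch of the same pattern, using torch.floor as a stand-in for the nondifferentiable function and local_a - 0.5 as its differentiable approximation (both are placeholders I chose for illustration, not from your use case):

```python
import torch

class FloorWithSurrogateGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        ctx.save_for_backward(a)
        # Exact but nondifferentiable value (floor has zero gradient a.e.)
        return torch.floor(a)

    @staticmethod
    def backward(ctx, grad_output):
        a, = ctx.saved_tensors
        # Detached leaf copy, so the surrogate graph is local to backward
        local_a = a.detach().requires_grad_()
        with torch.enable_grad():
            # Differentiable approximation of floor (placeholder)
            approx_out = local_a - 0.5
        # Gradient of the approximation, fed with the incoming grad_output
        return torch.autograd.grad(approx_out, local_a, grad_output)[0]

a = torch.tensor([2.7], requires_grad=True)
out = FloorWithSurrogateGrad.apply(a)
out.backward()
print(out.item(), a.grad.item())  # 2.0 1.0
```

The forward pass still returns the exact value (2.0), while the backward pass returns the approximation's gradient (1.0) instead of floor's true gradient of 0.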

Thanks a lot for your help! torch.autograd.grad was the missing piece :slight_smile: