I have a function `nondifferentiable_but_exact` that returns the exact value for a given argument but is not differentiable. However, I know an approximation of this function that is not exact but is differentiable and can easily be expressed as a PyTorch expression. Now, instead of deriving the gradient manually, which I would surely get wrong, I would like to simply take the gradient of the differentiable expression and return it for the non-differentiable one. I have played around with the following test code for doing this:

```python
import torch

class Test(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # TODO This does not work. The idea is to get the gradient
        # of b and return it for a.
        return None, None

a = torch.randn(2, requires_grad=True)  # example input
a = nondifferentiable_but_exact(a)
b = differentiable_but_approximate_function(a)

Test.apply(a, b).backward()
```

Unfortunately, I am unable to get the gradient of `b` inside the `backward` function. I have also tried calling `backward(retain_graph=True)` in there, but that does not seem to work either.

Is there a solution to this problem?

Some more background: I want to approximate the gradient of `nondifferentiable_but_exact` with a bilinear interpolation. So I will simply call `nondifferentiable_but_exact` five times: once for the exact value and four times to get the other points for the interpolation. I will discard the interpolated value and use only its gradient. Also, for my use case it is important that I use the exact value of the function, so I cannot simply use the interpolated value (and be done).
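For reference, `differentiable_but_approximate_function` could look roughly like this. This is only a sketch: it assumes a 2D input of shape `(2,)` and a hypothetical cell size `h`; the four corner evaluations are treated as constants, so the gradient comes entirely from the bilinear weights:

```python
import torch

def differentiable_but_approximate_function(xy, h=1e-2):
    # Sketch only: assumes xy has shape (2,) and that
    # nondifferentiable_but_exact accepts such a tensor.
    x0 = torch.floor(xy / h) * h  # lower-left corner of the cell
    u = (xy - x0) / h             # fractional position, in [0, 1)
    with torch.no_grad():         # the corner values are constants
        f00 = nondifferentiable_but_exact(x0)
        f10 = nondifferentiable_but_exact(x0 + xy.new_tensor([h, 0.0]))
        f01 = nondifferentiable_but_exact(x0 + xy.new_tensor([0.0, h]))
        f11 = nondifferentiable_but_exact(x0 + xy.new_tensor([h, h]))
    # Bilinear combination; the gradient flows only through u,
    # which yields finite-difference-like gradients of step h.
    return (f00 * (1 - u[0]) * (1 - u[1]) + f10 * u[0] * (1 - u[1])
            + f01 * (1 - u[0]) * u[1] + f11 * u[0] * u[1])
```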

The way I would do this is the following (assuming there are NO learnable parameters in `nondifferentiable_but_exact` or `differentiable_but_approximate_function`, and that you don't need higher-order gradients via `create_graph=True`):

```python
import torch

class Test(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a)
        exact_out = nondifferentiable_but_exact(a)
        return exact_out

    @staticmethod
    def backward(ctx, grad_output):
        a, = ctx.saved_tensors
        # Get new Tensor unrelated to a with same content
        local_a = a.detach().requires_grad_()
        # Re-run the approximation with grad tracking enabled
        # (backward runs in no-grad mode by default)
        with torch.enable_grad():
            approx_out = differentiable_but_approximate_function(local_a)
        # Backpropagate grad_output through the approximation
        grad_a, = torch.autograd.grad(approx_out, local_a, grad_output)
        return grad_a, None  # b receives no gradient
```
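For completeness, a quick usage sketch (the input shape and the scalar output are assumptions for illustration; `b` is accepted by `forward` but not used there):

```python
a = torch.randn(2, requires_grad=True)          # assumed input shape
b = differentiable_but_approximate_function(a)

out = Test.apply(a, b)  # forward returns the exact value
out.backward()          # gradient comes from the approximation
print(a.grad)
```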
Thanks a lot for your help! `torch.autograd.grad` was the missing piece.