I have a function nondifferentiable_but_exact
which returns the exact value for a given argument but is not differentiable. However, I know an approximation of this function which is not exact but is differentiable and can easily be expressed as a PyTorch expression. Instead of deriving the gradient by hand - which I will certainly get wrong - I would like to simply take the gradient of the differentiable expression and return it for the non-differentiable one. I have played around with the following test code for doing this:
import torch

class Test(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # TODO This does not work. The idea is to get the gradient of b and return it for a.
        grad_a = b.grad
        assert grad_a is not None
        return grad_a, None

a = torch.ones(1, requires_grad=True)
a = nondifferentiable_but_exact(a)
b = differentiable_but_approximate_function(a)
Test.apply(a, b).backward()
Unfortunately, I am unable to get the gradient of b
inside the backward
function: b.grad is always None there. I have also tried calling backward(retain_graph=True)
in there, but that does not seem to work either.
Is there a solution to this problem?
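To make the intent concrete, here is a minimal, self-contained sketch of what I am trying to achieve (the two functions are hypothetical stand-ins for my real ones). Instead of reading b.grad, it re-runs the approximation on a detached copy of the saved input inside backward - under torch.enable_grad(), since grad mode is off there - and differentiates it with torch.autograd.grad:

```python
import torch

# Hypothetical stand-ins for my real functions:
def differentiable_but_approximate(x):
    return x ** 2                # differentiable approximation

def nondifferentiable_but_exact(x):
    return torch.round(x ** 2)   # exact value, but zero/undefined gradient

class ExactWithApproxGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return nondifferentiable_but_exact(x)   # exact value in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Non-leaf tensors never get their .grad populated, so instead
        # recompute the approximation on a detached copy under enable_grad()
        # (autograd is disabled inside backward by default) and differentiate it.
        with torch.enable_grad():
            x_ = x.detach().requires_grad_()
            y = differentiable_but_approximate(x_)
            (grad_x,) = torch.autograd.grad(y, x_, grad_output)
        return grad_x

x = torch.tensor([1.3], requires_grad=True)
out = ExactWithApproxGrad.apply(x)   # exact value: round(1.3**2) = 2.0
out.backward()
print(x.grad)                        # gradient of the approximation: 2 * 1.3 = 2.6
```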
Some more background: I want to approximate the gradient of nondifferentiable_but_exact
with a bilinear interpolation. So I will simply call nondifferentiable_but_exact
five times: once for the exact value and four times to get neighbouring points for the interpolation. I will discard the interpolated value and use only its gradient. Also, for my use case it is important that I use the exact value of the function, so I cannot simply use the interpolated value (and be done).
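The 1+4-call plan above could be sketched like this for a 2-D input - here with central differences along each axis rather than true bilinear corner points, and with a made-up placeholder function and step size:

```python
import torch

def nondifferentiable_but_exact(x):
    # Hypothetical stand-in for my real function; expects a 2-D input
    return torch.floor(3.0 * x[0]) + torch.floor(2.0 * x[1])

class ExactWithInterpolatedGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return nondifferentiable_but_exact(x)    # call 1: the exact value

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        h = 0.5                                  # made-up step size
        grad = torch.zeros_like(x)
        for i in range(x.numel()):               # calls 2-5: neighbouring points
            e = torch.zeros_like(x)
            e[i] = h
            grad[i] = (nondifferentiable_but_exact(x + e)
                       - nondifferentiable_but_exact(x - e)) / (2 * h)
        # Chain rule: scale the estimated gradient by the incoming gradient
        return grad_output * grad

x = torch.tensor([0.9, 0.9], requires_grad=True)
y = ExactWithInterpolatedGrad.apply(x)   # exact value, e.g. floor(2.7) + floor(1.8) = 3.0
y.backward()
print(x.grad)                            # finite-difference estimate of the gradient
```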