While it is a bit confusing at first sight, there is little that can be done here; in particular, torch.where is working as expected.
When you instrument this as follows:
import torch

def f1(x):
    return 0 / x          # 0/0 gives NaN in the forward pass when x == 0

def f2(x):
    return x.clone()

x = torch.zeros(1, requires_grad=True)
r1 = f1(x)
r1.retain_grad()
r2 = f2(x)
r2.retain_grad()
c = torch.where(x == 0, r2, r1)  # selects r2 here because x == 0
c.backward()
print(r1.grad, r2.grad, x.grad)
Here you see that r1.grad is 0, r2.grad is 1 and x.grad is NaN.
The gradient x.grad is NaN because even when you feed a gradient of 0 into the backward of f1, you get NaN as the gradient of f1's input: the local derivative of 0/x at x == 0 is itself NaN, and 0 * NaN is still NaN.
The gradient of x is the sum of the input gradients of f1 and f2, i.e. NaN + 1 = NaN.
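A minimal sketch that isolates this (using only the standard torch.autograd.grad API; the variable names are just for illustration): feeding an all-zero upstream gradient into the backward of 0 / x at x == 0 still yields NaN.

import torch

x = torch.zeros(1, requires_grad=True)
y = 0 / x  # forward value is already NaN (0/0)
# Even an all-zero upstream gradient produces NaN for x:
# the local derivative of 0/x at x == 0 is NaN, and 0 * NaN = NaN.
grad_x, = torch.autograd.grad(y, x, grad_outputs=torch.zeros_like(y))
print(grad_x)  # tensor([nan])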
Best regards
Thomas