Gradients of torch.where

While the result is a bit confusing at first sight, there is little that can be done about it, and in particular, torch.where is working as expected.
You can see this by instrumenting the example as follows:

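import torch

def f1(x):
    # 0/x is not finite at x == 0, and neither is its local derivative
    return 0/x

def f2(x):
    return x.clone()

x = torch.zeros(1, requires_grad=True)
r1 = f1(x)
r1.retain_grad()  # keep the gradient of this non-leaf tensor for inspection
r2 = f2(x)
r2.retain_grad()
c = torch.where(x == 0, r2, r1)  # x == 0, so r2 is selected
c.backward()
print(r1.grad, r2.grad, x.grad)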

Here you see that r1.grad is 0, r2.grad is 1 and x.grad is NaN.

x.grad is NaN because the backward of f1 returns NaN for its input even when the gradient fed into it is 0: the local derivative of 0/x is not finite at x = 0, and 0 multiplied by a non-finite value gives NaN.
The gradient of x is the sum of the gradients that flow back through f1 and f2, i.e. NaN + 1 = NaN.
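
To make this concrete, here is a minimal sketch that takes torch.where out of the picture and runs the two backward passes separately. The names grad_via_f1 and grad_via_f2 are only illustrative; the zero upstream gradient stands in for what torch.where passes to the branch it did not select.

```python
import torch

x = torch.zeros(1, requires_grad=True)

# Branch f1: feed an upstream gradient of exactly 0 into the backward of 0/x.
r1 = 0 / x
r1.backward(torch.zeros(1))
grad_via_f1 = x.grad.clone()
print(grad_via_f1)  # expected to be NaN, even though the incoming gradient was 0

# Branch f2: reset x.grad, then feed an upstream gradient of 1 into clone().
x.grad = None
r2 = x.clone()
r2.backward(torch.ones(1))
grad_via_f2 = x.grad.clone()
print(grad_via_f2)  # expected to be 1

# Autograd accumulates both contributions into x.grad, so the NaN wins.
print(grad_via_f1 + grad_via_f2)
```

So the NaN is produced inside the backward of f1 itself, before the contributions are summed, and torch.where has no opportunity to mask it out.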

Best regards

Thomas
