Gradients of torch.where


I am trying to calculate gradients of a function that uses torch.where, but it produces unexpected gradients. I use it to choose between a real case, a complex case and a limit case, where some of the cases have a NaN gradient for specific inputs. For simplicity, consider the following example:

```python
import torch

def f1(x):
    return 0 / x  # NaN at x == 0, and its gradient is NaN there too

def f2(x):
    return x

def g(x):
    r1 = f1(x)
    r2 = f2(x)
    # Take f2's value where x == 0 and f1's value elsewhere
    return torch.where(x == 0.0, r2, r1)

x = torch.zeros(1, requires_grad=True)
print(torch.autograd.grad(g(x), x))
# (tensor([nan]),)
print(torch.autograd.grad(f2(x), x))
# (tensor([1.]),)
```

I would expect the gradient of g(x) to be the same as that of f2(x), but they are clearly different. The problem only occurs when one of the branches has a NaN gradient; e.g. if I change f1 to something like x**2, I get the correct result. I do not know how gradients are propagated through torch.where, but if it is something like

torch.where(condition, tensor1, tensor2) = condition * tensor1 + (1-condition) * tensor2

then it makes sense that the resulting gradient is NaN even when condition = 1 (so tensor1 is selected) and only tensor2 has a NaN gradient, because 0 * NaN is still NaN.
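The arithmetic fact behind this can be checked directly. A minimal sketch (the variable names are only for illustration) showing that under IEEE 754 floating point, which PyTorch tensors follow, multiplying 0 by inf or by NaN yields NaN rather than 0, so a zero factor cannot cancel a non-finite one:

```python
import torch

zero = torch.tensor(0.0)
inf = torch.tensor(float("inf"))
nan = torch.tensor(float("nan"))

# IEEE 754: 0 * inf and 0 * NaN are both NaN, not 0
print(zero * inf)  # tensor(nan)
print(zero * nan)  # tensor(nan)
```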

Is this the intended behavior of torch.where, or is some NaN checking missing? If it is intended, is there a way to get around it?

Thank you

While it is a bit confusing at first sight, there is little that can be done here; in particular, torch.where is working as expected.
You can see this when you instrument the example as follows:

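```python
import torch

def f1(x):
    return 0 / x

def f2(x):
    return x.clone()

x = torch.zeros(1, requires_grad=True)
r1 = f1(x)
r2 = f2(x)
# r1 and r2 are non-leaf tensors, so ask autograd to keep their .grad
r1.retain_grad()
r2.retain_grad()
c = torch.where(x == 0, r2, r1)
c.backward()
print(r1.grad, r2.grad, x.grad)
# tensor([0.]) tensor([1.]) tensor([nan])
```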

Here you see that r1.grad is 0, r2.grad is 1 and x.grad is NaN.
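To look at each branch on its own, a sketch like the following (not part of the original reply) calls torch.autograd.grad on each branch separately and passes, through its grad_outputs argument, the upstream gradient that torch.where hands to that branch: zeros for the unselected f1 branch and ones for the selected f2 branch.

```python
import torch

x = torch.zeros(1, requires_grad=True)

# f1 branch: upstream gradient of exactly 0, as torch.where gives the unselected input
r1 = 0 / x
(g1,) = torch.autograd.grad(r1, x, grad_outputs=torch.zeros_like(r1))

# f2 branch: upstream gradient of 1, as torch.where gives the selected input
r2 = x.clone()
(g2,) = torch.autograd.grad(r2, x, grad_outputs=torch.ones_like(r2))

print(g1)       # tensor([nan])
print(g2)       # tensor([1.])
print(g1 + g2)  # tensor([nan]), the same as x.grad above
```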

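The gradient x.grad is NaN because the backward of f1 involves a factor of 1/x², which is infinite at x = 0: even when you feed a gradient of exactly 0 into the backward of f1, it computes 0 * inf and you get NaN as the gradient of f1’s input.
The gradient of x is the sum of the gradients flowing back through f1 and f2, i.e. NaN + 1 = NaN.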
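One commonly used way around this (a sketch, not something proposed in this thread) is to make sure the unselected branch never evaluates at the problematic input: a first torch.where replaces x == 0 in f1's input with a harmless placeholder (1.0 here, an arbitrary choice), and the usual torch.where then selects between the outputs. The function name g_safe is hypothetical.

```python
import torch

def f1(x):
    return 0 / x

def f2(x):
    return x.clone()

def g_safe(x):
    is_limit = x == 0.0
    # Swap the singular input for a placeholder before calling f1,
    # so f1's forward and backward stay finite everywhere
    safe_x = torch.where(is_limit, torch.ones_like(x), x)
    # Outer where: f2 at the limit point, f1 elsewhere
    return torch.where(is_limit, f2(x), f1(safe_x))

x = torch.zeros(1, requires_grad=True)
(grad,) = torch.autograd.grad(g_safe(x), x)
print(grad)  # tensor([1.])
```

The placeholder never reaches the result, because the outer torch.where discards f1's output wherever x == 0; it only has to keep f1 and its backward finite at those points.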

Best regards


This seems connected to something I noted a while ago in the thread “Bug or feature? NaNs influence other variables in backprop”.
In that case, there was a multiplication by 0 inside PyTorch’s backward code that computed 0 * NaN and got NaN. My workaround has been to substitute NaNs with zeros.
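A minimal sketch of how the NaN-to-zero substitution could be applied to the example from the question, assuming it is done on the gradient of the f1 branch only (the helper name g_nan_to_zero is hypothetical). f1 gets its own clone of x, so a hook registered with Tensor.register_hook sees only the gradient flowing back through f1 and can clean it with torch.nan_to_num before it is summed with f2's contribution:

```python
import torch

def f1(x):
    return 0 / x

def f2(x):
    return x.clone()

def g_nan_to_zero(x):
    # Separate clone of x that only feeds f1, so its hook sees only f1's gradient
    x1 = x.clone()
    # Replace NaN entries of the gradient coming back through f1 with 0
    x1.register_hook(lambda grad: torch.nan_to_num(grad, nan=0.0))
    return torch.where(x == 0.0, f2(x), f1(x1))

x = torch.zeros(1, requires_grad=True)
(grad,) = torch.autograd.grad(g_nan_to_zero(x), x)
print(grad)  # tensor([1.])
```

Registering the hook on x itself would not help: by then f1's NaN has already been added to f2's 1, and zeroing the sum would give 0 instead of 1. This approach also silently hides genuine NaN gradients from f1 at points where f1 is actually selected.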

Anyway, even if it’s expected behavior, I don’t feel the PyTorch implementation is the most reasonable one.