Hello,
I am trying to calculate gradients of a function that uses `torch.where`, but I get unexpected gradients. I use it to choose between a real case, a complex case and a limit case, where some of the branches have a NaN gradient for specific inputs. For simplicity, consider the following example:
```python
import torch

def f1(x):
    return 0 / x

def f2(x):
    return x

def g(x):
    r1 = f1(x)
    r2 = f2(x)
    return torch.where(x == 0.0, r2, r1)

x = torch.zeros(1, requires_grad=True)
print(torch.autograd.grad(g(x), x))
# (tensor([nan]),)
print(torch.autograd.grad(f2(x), x))
# (tensor([1.]),)
```
I would expect the gradient of `g(x)` to equal that of `f2(x)`, but they clearly differ. The problem only occurs when one of the branches has a NaN gradient: if I change `f1` to something like `x**2`, I get the correct result. I do not know how gradients are propagated through `torch.where`, but if it is something like

`torch.where(condition, tensor1, tensor2) = condition * tensor1 + (1 - condition) * tensor2`

then it makes sense that the resulting gradient is NaN: even though `condition = 1` here (so the weight `1 - condition` on `tensor2` is 0) and only `tensor2` has a NaN gradient, 0 * NaN is still NaN.
Is this the intended behavior of `torch.where`, or is some NaN check missing? If it is intended, is there a way to work around it?
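A commonly used workaround, sketched here under assumptions, is a "double where": first replace the problematic inputs with a harmless value before calling the unsafe branch, so its forward and backward never see `x == 0`. The helper name `g_safe` and the substitute value `1.0` are hypothetical choices for illustration; any value at which `f1` is finite and differentiable would do.

```python
import torch

def f1(x):
    return 0 / x

def f2(x):
    return x

def g_safe(x):
    # Hypothetical helper: mask the inputs *before* the unsafe branch.
    is_limit = x == 0.0
    # Feed f1 a harmless value (1.0, an arbitrary choice) where x == 0,
    # so neither its forward nor its backward produces inf/nan there.
    x_safe = torch.where(is_limit, torch.ones_like(x), x)
    # The outer where still selects f2 at the limit point, as before.
    return torch.where(is_limit, f2(x), f1(x_safe))

x = torch.zeros(1, requires_grad=True)
(grad,) = torch.autograd.grad(g_safe(x), x)
print(grad)  # now matches the gradient of f2 at x == 0
```

The inner `where` also masks the gradient flowing back through `x_safe`, so the masked-out positions receive no contribution from `f1` at all.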
Thank you