Hello,

I am trying to compute gradients of a function that uses torch.where, but the gradients come out differently from what I expect. I use it to choose between a real case, a complex case and a limit case, where some of the cases have a NaN gradient for specific inputs. For simplicity, consider the following example:

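```python
import torch

def f1(x):
    return 0 / x

def f2(x):
    return x

def g(x):
    r1 = f1(x)
    r2 = f2(x)
    # Pick f2 at x == 0, where f1 has no finite value or gradient
    return torch.where(x == 0, r2, r1)

x = torch.zeros(1, requires_grad=True)
print(torch.autograd.grad(g(x), x))
# >> (tensor([nan]),)
print(torch.autograd.grad(f2(x), x))
# >> (tensor([1.]),)
```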

I would expect the gradient of g(x) to be the same as that of f2(x), but they are clearly different. This problem only happens when one of the functions has a NaN gradient, i.e. if I change f1 to something like x**2 I get the correct result. I do not know how gradients are propagated through torch.where, but if it is something like

```
torch.where(condition, tensor1, tensor2) = condition * tensor1 + (1 - condition) * tensor2
```

then it makes sense that the resulting gradient is NaN, even if condition = 0 and only tensor1 has a NaN gradient.
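To make that guess concrete, here is a small sketch that writes the blend out by hand. The `blend` helper is hypothetical and only mirrors the formula above; it is not how PyTorch implements torch.where. Unlike torch.where, it also makes the forward value NaN, but the gradient picture is the same: the unselected branch gets a zero upstream gradient, and that zero is then multiplied by a non-finite local derivative.

```python
import torch

def blend(condition, tensor1, tensor2):
    # Hypothetical arithmetic stand-in for torch.where (assumption, not PyTorch's code)
    c = condition.to(tensor1.dtype)
    return c * tensor1 + (1 - c) * tensor2

x = torch.zeros(1, requires_grad=True)
# condition is False at x == 0, so tensor2 = x is the "selected" branch
out = blend(x != 0, 0 / x, x)
(grad,) = torch.autograd.grad(out, x)
print(grad)  # NaN: the zero weight on 0 / x does not cancel its undefined derivative
```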

Is this the intended behavior of torch.where, or is some NaN checking missing? If it is intended, is there a way to get around it?

Thank you

While it is a bit confusing at first sight, there is little that can be done here, and in particular, `where` is working as expected.
When you instrument this as

```python
import torch

def f1(x):
    return 0 / x

def f2(x):
    return x.clone()

x = torch.zeros(1, requires_grad=True)  # assumed input, as in the question
r1 = f1(x)
r2 = f2(x)
c = torch.where(x == 0, r2, r1)
c.backward()
```
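One way to look at what happens in this snippet is to keep the gradients of the intermediate results with `retain_grad()`. The sketch below does that and then tries one possible workaround, which is an assumption on my part rather than something stated in the thread: give the unselected branch a harmless input (the names `x2` and `safe_x` are made up for illustration) so its local derivative stays finite.

```python
import torch

def f1(x):
    return 0 / x

def f2(x):
    return x.clone()

x = torch.zeros(1, requires_grad=True)
r1 = f1(x)
r2 = f2(x)
# Keep gradients of the non-leaf intermediates so they can be inspected
r1.retain_grad()
r2.retain_grad()
c = torch.where(x == 0, r2, r1)
c.backward()
print(r1.grad, r2.grad, x.grad)

# Possible workaround (assumption): replace the problematic input before calling f1,
# so the branch that is not selected never sees x == 0
x2 = torch.zeros(1, requires_grad=True)
safe_x = torch.where(x2 == 0, torch.ones_like(x2), x2)
c2 = torch.where(x2 == 0, f2(x2), f1(safe_x))
c2.backward()
print(x2.grad)
```

In the first part, `where` hands a zero gradient to `r1` and a one to `r2`, which is what you would hope for. The NaN only appears one step later, when that zero is pushed back through `0 / x` at `x == 0`. In the second part the division never sees zero, so the gradient of `x2` stays finite.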