# Unexpected NaN gradient propagated when applying boolean conditions

Hi,

I am facing an unexpected autograd behavior when applying boolean conditions.

Below is a minimal example:

I evaluate the sinc function (equal to sin(x)/x for x != 0 and 1 otherwise), using either the base definition of sinc or a Taylor series expansion close to 0 to avoid numerical issues.

``````
import torch

for value in (0.0, 1e-7, 1e-3):
    x = torch.tensor(value, requires_grad=True)
    # Base definition valid for x != 0
    sinc_base = torch.sin(x) / x
    # Taylor series expansion valid for x close to 0
    sinc_taylor = 1.0 - x**2 / 6
    # Switching between the two expressions depending on the value of x
    condition = torch.abs(x) < 1e-6
    sinc = torch.where(condition, sinc_taylor, sinc_base)

    print(f"---- x={x.item()} ----")
    grad_base, = torch.autograd.grad(sinc_base, x, retain_graph=True)
    print(f"sinc_base: {sinc_base.item()} -- grad: {grad_base.item()}")
    grad_taylor, = torch.autograd.grad(sinc_taylor, x, retain_graph=True)
    print(f"sinc_taylor: {sinc_taylor.item()} -- grad: {grad_taylor.item()}")
    print("condition:", condition.item())
    grad_sinc, = torch.autograd.grad(sinc, x)
    print(f"sinc: {sinc.item()} -- grad: {grad_sinc.item()}")

# Output:
# ---- x=0.0 ----
# sinc_base: nan -- grad: nan
# sinc_taylor: 1.0 -- grad: -0.0
# condition: True
# sinc: 1.0 -- grad: nan
# ---- x=1.0000000116860974e-07 ----
# sinc_base: 1.0 -- grad: 0.0
# sinc_taylor: 1.0 -- grad: -3.33333360913457e-08
# condition: True
# sinc: 1.0 -- grad: -3.33333360913457e-08
# ---- x=0.0010000000474974513 ----
# sinc_base: 0.9999998807907104 -- grad: -0.0003662109375
# sinc_taylor: 0.9999998211860657 -- grad: -0.0003333333588670939
# condition: False
# sinc: 0.9999998807907104 -- grad: -0.0003662109375
``````

The gradient of sinc should be finite everywhere. For x=0.0, however, the Taylor expansion sinc_taylor is the branch selected to evaluate sinc, yet autograd returns NaN instead of the gradient of sinc_taylor.

Do you think it is a bug, or do you have any workaround to suggest?

Here is an even simpler example to illustrate the issue I am facing:

``````
import torch

x = torch.zeros(1, requires_grad=True)
x.squeeze().backward()
# The gradient of x with respect to itself should be 1.
# We get the expected result:
print(x.grad)  # tensor([1.])

x = torch.zeros(1, requires_grad=True)
other_value = 2 * x + 1
condition = torch.as_tensor(True)
y = torch.where(condition, x, other_value)
y.squeeze().backward()
# We define y equal to x using a boolean condition.
# The gradient of y with respect to x should be equal to 1.
# We get the expected result:
print(x.grad)  # tensor([1.])

x = torch.zeros(1, requires_grad=True)
nan = 0.0 / x
condition = torch.as_tensor(True)
y = torch.where(condition, x, nan)
y.squeeze().backward()
# We define y equal to x using a boolean condition.
# The gradient of y with respect to x should be equal to 1.
# We get nan:
print(x.grad)  # tensor([nan])
``````

I can think of the following scenario:

`torch.where(condition, x, y)` passes back `0` gradient to the elements that are not selected by the `condition`.

``````
x = torch.zeros(1, requires_grad=True)
out = 0.0 / x
condition = torch.as_tensor(True)
y = torch.where(condition, x, out)
``````

During backward, the `0` gradient sent to `out` gets multiplied by the local gradient `d(out)/dx`, which is `nan` at `x = 0`, so `x` receives `0 * nan = nan`.
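
To make that step concrete, here is a small sketch continuing the snippet above (the extra prints are only for illustration):

``````
import torch

x = torch.zeros(1, requires_grad=True)
out = 0.0 / x                      # forward value is nan at x = 0
condition = torch.as_tensor(True)
y = torch.where(condition, x, out)
y.squeeze().backward()

# The zero gradient routed to the non-selected branch does not cancel the nan
# local gradient d(out)/dx, because 0 * nan is still nan in floating point:
print(0.0 * float("nan"))          # nan
print(x.grad)                      # tensor([nan]) instead of tensor([1.])
``````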

Hi rbregier!

Arul’s explanation is correct. `torch.where()` backpropagates a gradient
of `0 * nan = nan` (or maybe `0 * inf = nan`) through the “branch not
taken.”

An example:

``````
>>> import torch
>>> torch.__version__
'1.10.2'
>>> s1 = torch.zeros (4, requires_grad = True)
>>> s2 = torch.zeros (4, requires_grad = True)
>>> t1 = torch.tensor ([0.5, 0.5, 0.0, 0.0])
>>> t2 = torch.tensor ([0.5, 0.5, 1.e-7, 1.e-7])
>>> torch.where (t1 > 1.e-6, s1 / t1, s1).sum().backward()
>>> torch.where (t2 > 1.e-6, s2 / t2, s2).sum().backward()
>>> s1.grad
tensor([2., 2., nan, nan])
>>> s2.grad
tensor([2., 2., 1., 1.])   # no gradients of 1.e7
``````

It’s certainly a known issue, but I don’t think the devs consider it a bug.
Apparently, it comes from deep inside how autograd handles `where()`
and would be difficult to fix.

This GitHub issue gives some explanation:

My approach is to get rid of the `nan`s. You can safely feed an incorrect
value to the “zero” branch of `torch.where()`, as long as it’s not `nan`
(or `inf`, etc.).

In your specific case, I would take advantage of the fact that `sinc()` is an
even function (`sinc (-x) = sinc (x)`), and `clamp()` the denominator
away from zero:

``````
sinc_base = torch.sin (x.abs()) / x.abs().clamp (1.e-7)
``````

So two things happen: for `x.abs() < 1.e-7`, `sinc_base` will be an
incorrect value, but it won’t be `nan`. However, for `x.abs() < 1.e-6`,
`torch.where()` will switch you over to `sinc_taylor`, so you will never
see the incorrect `sinc_base` values.

Then, for gradients, when `abs (x) < 1.e-6`, `torch.where()` will (in part)
backpropagate `0 * sinc_base_gradient`. Although `sinc_base_gradient`
will be incorrect for `abs (x) < 1.e-7`, it won’t be `nan`, so autograd will
correctly backpropagate `0` (rather than `nan`) for this piece of the
`abs (x) < 1.e-6` branch.
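
Putting the two pieces together, a rough sketch of what this could look like (reusing the thresholds from above; `safe_sinc` is just an illustrative name):

``````
import torch

def safe_sinc (x):
    # clamp the denominator away from zero so that the "branch not taken"
    # never produces nan values or nan gradients
    x_abs = x.abs()
    sinc_base = torch.sin (x_abs) / x_abs.clamp (1.e-7)
    # Taylor expansion, selected by where() close to zero
    sinc_taylor = 1.0 - x**2 / 6
    return torch.where (x_abs < 1.e-6, sinc_taylor, sinc_base)

x = torch.tensor ([0.0, 1.e-7, 1.e-3], requires_grad = True)
safe_sinc (x).sum().backward()
print (x.grad)   # finite everywhere, no nan at x = 0
``````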

Best.

K. Frank


Thank you Arul and K. Frank for your replies. I would not have thought that the PyTorch devs consider this the expected behavior. I will probably just try to avoid non-finite values as you suggested.