Unexpected NaN gradient propagated when applying boolean conditions

Hi,

I am facing an unexpected autograd behavior when applying boolean conditions.

Below is a minimal example:

I evaluate the sinc function (equal to sin(x)/x for x != 0 and 1 otherwise), using either the base definition of sinc or a Taylor series expansion close to 0 to avoid numerical issues.

import torch

for value in (0.0, 1e-7, 1e-3):
    x = torch.as_tensor(value).requires_grad_(True)
    # Base definition valid for x != 0
    sinc_base = torch.sin(x)/x
    # Taylor series expansion valid for x close to 0
    sinc_taylor = 1.0 - x**2 / 6 
    # Switching between the two expressions depending on the value of x
    condition = torch.abs(x) < 1e-6
    sinc = torch.where(condition, sinc_taylor, sinc_base)

    print(f"---- x={x.item()} ----")
    print(f"sinc_base: {sinc_base.item()} -- grad: {torch.autograd.grad(sinc_base, x, retain_graph=True)[0].item()}")
    print(f"sinc_taylor: {sinc_taylor.item()} -- grad: {torch.autograd.grad(sinc_taylor, x, retain_graph=True)[0].item()}")
    print(f"condition:", condition.item())
    print(f"sinc: {sinc.item()} -- grad: {torch.autograd.grad(sinc, x, retain_graph=True)[0].item()}")
    
    
# Output:
# ---- x=0.0 ----
# sinc_base: nan -- grad: nan
# sinc_taylor: 1.0 -- grad: -0.0
# condition: True
# sinc: 1.0 -- grad: nan
# ---- x=1.0000000116860974e-07 ----
# sinc_base: 1.0 -- grad: 0.0
# sinc_taylor: 1.0 -- grad: -3.33333360913457e-08
# condition: True
# sinc: 1.0 -- grad: -3.33333360913457e-08
# ---- x=0.0010000000474974513 ----
# sinc_base: 0.9999998807907104 -- grad: -0.0003662109375
# sinc_taylor: 0.9999998211860657 -- grad: -0.0003333333588670939
# condition: False
# sinc: 0.9999998807907104 -- grad: -0.0003662109375

The gradient of sinc should be finite everywhere. For x=0.0, however, the Taylor series sinc_taylor is selected to compute sinc, yet autograd returns NaN instead of the gradient of sinc_taylor.

Do you think it is a bug, or do you have any workaround to suggest?

Here is an even simpler example to illustrate the issue I am facing:

import torch

x = torch.zeros(1, requires_grad=True)
x.squeeze().backward()
print(f"{x=} -- {x.grad=}")
# The gradient of x with respect to itself should be 1.
# We get the expected result:
# --> x=tensor([0.], requires_grad=True) -- x.grad=tensor([1.])

x = torch.zeros(1, requires_grad=True)
other_value = 2 * x + 1
condition = torch.as_tensor(True)
y = torch.where(condition, x, other_value)
y.squeeze().backward()
print(f"{y=} -- {x.grad=}")
# We define y equal to x using a boolean condition.
# The gradient of y with respect to x should be equal to 1.
# We get the expected result:
# --> y=tensor([0.], grad_fn=<SWhereBackward0>) -- x.grad=tensor([1.])

x = torch.zeros(1, requires_grad=True)
nan = 0.0 / x
condition = torch.as_tensor(True)
y = torch.where(condition, x, nan)
y.squeeze().backward()
print(f"{y=} -- {x.grad=}")
# We define y equal to x using a boolean condition.
# The gradient of y with respect to x should be equal to 1.
# We get nan:
# --> y=tensor([0.], grad_fn=<SWhereBackward0>) -- x.grad=tensor([nan])

I could think of the following scenario:

torch.where(condition, x, y) passes back 0 gradient to the elements that are not selected by the condition.

x = torch.zeros(1, requires_grad=True)
out = 0.0 / x
condition = torch.as_tensor(True)
y = torch.where(condition, x, out)

During the backward pass, the local gradient of out w.r.t. x (d(out)/dx = -0.0/x**2) is nan at x = 0, and multiplying it by the 0 gradient coming from torch.where gives 0 * nan = nan.
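
A quick sketch of both steps (just illustrating the arithmetic, not the actual autograd internals):

import torch

# torch.where() really does send a gradient of 0 to the branch not taken:
a = torch.zeros(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
torch.where(torch.as_tensor(True), a, b).squeeze().backward()
print(a.grad, b.grad)                # tensor([1.]) tensor([0.])

# The chain rule then multiplies that 0 by the local derivative of out = 0.0 / x,
# which is -0.0 / x**2 = nan at x = 0, and 0 * nan = nan:
x = torch.zeros(1, requires_grad=True)
out = 0.0 / x
local_grad = torch.autograd.grad(out.sum(), x)[0]
print(local_grad)                    # tensor([nan])
print(torch.zeros(1) * local_grad)   # tensor([nan])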

Hi rbregier!

Arul’s explanation is correct. torch.where() backpropagates a gradient
of 0 * nan = nan (or maybe 0 * inf = nan) through the “branch not
taken.”

An example:

>>> import torch
>>> torch.__version__
'1.10.2'
>>> s1 = torch.zeros (4, requires_grad = True)
>>> s2 = torch.zeros (4, requires_grad = True)
>>> t1 = torch.tensor ([0.5, 0.5, 0.0, 0.0])
>>> t2 = torch.tensor ([0.5, 0.5, 1.e-7, 1.e-7])
>>> torch.where (t1 > 1.e-6, s1 / t1, s1).sum().backward()
>>> torch.where (t2 > 1.e-6, s2 / t2, s2).sum().backward()
>>> s1.grad
tensor([2., 2., nan, nan])
>>> s2.grad
tensor([2., 2., 1., 1.])   # no spurious gradients of 1.e7 leaking from the untaken s2 / t2 branch

It’s certainly a known issue, but I don’t think the devs consider it a bug.
Apparently, it's rooted deep in how autograd handles where() and would be
difficult to fix.

This GitHub issue gives some explanation:

My approach is to get rid of the nans. You can safely feed an incorrect
value to the “zero” branch of torch.where(), as long as it’s not nan
(or inf, etc.).
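
For instance, on your simpler 0.0 / x snippet above, you could feed a harmless
stand-in denominator to the branch you are not going to take (the 1.0 here is
arbitrary; any finite value works):

import torch

x = torch.zeros(1, requires_grad=True)
condition = torch.as_tensor(True)
# substitute a harmless denominator wherever the untaken branch would divide by zero
safe_x = torch.where(condition, torch.ones_like(x), x)
other_value = 0.0 / safe_x                 # finite everywhere now
y = torch.where(condition, x, other_value)
y.squeeze().backward()
print(f"{y=} -- {x.grad=}")
# --> x.grad=tensor([1.]) instead of tensor([nan])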

In your specific case, I would take advantage of the fact that sinc() is an
even function (sinc (-x) = sinc (x)), and clamp() the denominator
away from zero:

sinc_base = torch.sin (x.abs()) / x.abs().clamp (1.e-7)

So two things happen: For x.abs() < 1.e-7, sinc_base will be an
incorrect value, but it won’t be nan. However, for x.abs() < 1.e-6,
torch.where() will switch you over to sinc_taylor, so you will never
see the incorrect sinc_base values.

Then, for gradients, when abs (x) < 1.e-6, torch.where() will (in part)
backpropagate 0 * sinc_base_gradient. Although sinc_base_gradient
will be incorrect for abs (x) < 1.e-7, it won’t be nan, so autograd will
correctly backpropagate 0 (rather than nan) for this piece of the
abs (x) < 1.e-6 branch.
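
Putting it together with your sinc example (same thresholds as in your code),
something like this should give finite gradients everywhere:

import torch

for value in (0.0, 1e-7, 1e-3):
    x = torch.as_tensor(value).requires_grad_(True)
    # clamp the denominator away from zero so sinc_base is never nan;
    # it is only (harmlessly) wrong where the Taylor branch is selected anyway
    sinc_base = torch.sin(x.abs()) / x.abs().clamp(min=1e-7)
    sinc_taylor = 1.0 - x**2 / 6
    sinc = torch.where(x.abs() < 1e-6, sinc_taylor, sinc_base)
    grad = torch.autograd.grad(sinc, x)[0]
    print(f"x={value}: sinc={sinc.item()} -- grad: {grad.item()}")
# at x = 0.0 the gradient is now -0.0 (from sinc_taylor) instead of nan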

Best.

K. Frank


Thank you Arul and K. Frank for your replies. I would not have thought that the PyTorch devs consider this the expected behavior. I will probably just try to avoid non-finite values as you suggested.