Unexpected NaN gradient propagated when applying boolean conditions

Hi rbregier!

Arul’s explanation is correct. torch.where() backpropagates a zero gradient
through the “branch not taken,” but that zero then gets multiplied by the inf
gradient of the division in that branch (1 / t with t == 0), and 0 * inf = nan.

An example:

>>> import torch
>>> torch.__version__
'1.10.2'
>>> s1 = torch.zeros (4, requires_grad = True)
>>> s2 = torch.zeros (4, requires_grad = True)
>>> t1 = torch.tensor ([0.5, 0.5, 0.0, 0.0])
>>> t2 = torch.tensor ([0.5, 0.5, 1.e-7, 1.e-7])
>>> torch.where (t1 > 1.e-6, s1 / t1, s1).sum().backward()
>>> torch.where (t2 > 1.e-6, s2 / t2, s2).sum().backward()
>>> s1.grad
tensor([2., 2., nan, nan])
>>> s2.grad
tensor([2., 2., 1., 1.])   # no gradients of 1.e7

It’s certainly a known issue, but I don’t think the devs consider it a bug.
Apparently it’s rooted deep in how autograd interacts with where()
and would be difficult to fix.

This github issue gives some explanation:

My approach is to get rid of the nans. You can safely feed an incorrect
value to the “zero” branch of torch.where(), as long as it’s not nan
(or inf, etc.).
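Applied to the toy example above, that just means clamping the denominator
before the division. The branch-not-taken gradient is then a large finite
number (1.e7) instead of inf, so where() can zero it out cleanly. A quick
illustrative sketch (s3 / t3 are just fresh names; I would expect the same
gradient as in the s2 case):

>>> s3 = torch.zeros (4, requires_grad = True)
>>> t3 = torch.tensor ([0.5, 0.5, 0.0, 0.0])
>>> torch.where (t3 > 1.e-6, s3 / t3.clamp (1.e-7), s3).sum().backward()
>>> s3.grad
tensor([2., 2., 1., 1.])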

In your specific case, I would take advantage of the fact that sinc() is an
even function (sinc (-x) = sinc (x)), and clamp() the denominator
away from zero:

sinc_base = torch.sin (x.abs()) / x.abs().clamp (1.e-7)

So two things happen: For x.abs() < 1.e-7, sinc_base will be an
incorrect value, but it won’t be nan. However, for x.abs() < 1.e-6,
torch.where() will switch you over to sinc_taylor, so you will never
see the incorrect sinc_base values.

Then, for gradients, when x.abs() < 1.e-6, torch.where() will (in part)
backpropagate 0 * sinc_base_gradient. Although sinc_base_gradient
will be incorrect for x.abs() < 1.e-7, it won’t be nan, so autograd will
correctly backpropagate 0 (rather than nan) for this piece of the
x.abs() < 1.e-6 branch.
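Putting the pieces together, here is a minimal sketch of the whole
construction. (The particular Taylor polynomial below is just an assumption
on my part; plug in whatever sinc_taylor you already have.)

import torch

def sinc_safe (x):
    # clamp keeps the denominator away from zero, so sinc_base is never nan
    # (it is merely incorrect for x.abs() < 1.e-7, values we never use)
    sinc_base = torch.sin (x.abs()) / x.abs().clamp (1.e-7)
    # low-order Taylor expansion of sinc(), accurate near zero (assumed here)
    sinc_taylor = 1.0 - x**2 / 6.0 + x**4 / 120.0
    # because sinc_base is finite everywhere, the 0 * sinc_base_gradient
    # term for the branch not taken is 0, not nan
    return torch.where (x.abs() < 1.e-6, sinc_taylor, sinc_base)

x = torch.tensor ([0.0, 1.e-8, 0.5, 2.0], requires_grad = True)
sinc_safe (x).sum().backward()
print (x.grad)   # finite everywhere, including at x == 0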

Best.

K. Frank
