It’s certainly a known issue, but I don’t think the devs consider it a bug.
Apparently, it’s caused deep inside how autograd works with where()
and would be difficult to fix.
This github issue gives some explanation:
My approach is to get rid of the nans. You can safely feed an incorrect
value to the “zero” branch of torch.where(), as long as it’s not nan
(or inf, etc.).
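To see why the nan in the unused branch matters, here is a minimal sketch of the problem (the values are illustrative). Even though torch.where() patches the forward value at x = 0, autograd still backpropagates through the 0 / 0 branch, and 0 * nan is nan:

```python
import torch

# Naive sinc: the forward value at x = 0 is patched by where(),
# but the sin(x) / x branch still produces a nan gradient there,
# because its derivative at x = 0 evaluates to 0 / 0.
x = torch.tensor([0.0], requires_grad=True)
y = torch.where(x == 0.0, torch.ones_like(x), torch.sin(x) / x)
y.sum().backward()
# x.grad is nan here, even though y itself is fine
```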
In your specific case, I would take advantage of the fact that sinc() is an even function (sinc(-x) = sinc(x)) and clamp() the denominator away from zero:
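A sketch of this approach (the names sinc_base and sinc_taylor, the thresholds, and the Taylor terms are my choices for illustration):

```python
import torch

def safe_sinc(x):
    # sinc is even, so we may work with |x| and clamp the denominator
    # away from zero. The clamp makes sinc_base wrong (but not nan)
    # for |x| < 1.e-7; those values are never selected by where().
    xa = x.abs()
    sinc_base = torch.sin(xa) / torch.clamp(xa, min=1.e-7)
    # Taylor expansion of sinc near zero: 1 - x^2/6 + x^4/120
    sinc_taylor = 1 - x**2 / 6 + x**4 / 120
    return torch.where(xa < 1.e-6, sinc_taylor, sinc_base)

x = torch.tensor([0.0, 1.e-8, 1.0], requires_grad=True)
y = safe_sinc(x)
y.sum().backward()
# x.grad is finite everywhere, including at x = 0
```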
So two things happen: For x.abs() < 1.e-7, sinc_base will be an incorrect value, but it won’t be nan. However, for x.abs() < 1.e-6, torch.where() will switch you over to sinc_taylor, so you will never
see the incorrect sinc_base values.
Then, for gradients, when abs(x) < 1.e-6, torch.where() will (in part)
backpropagate 0 * sinc_base_gradient. Although sinc_base_gradient
will be incorrect for abs(x) < 1.e-7, it won’t be nan, so autograd will
correctly backpropagate 0 (rather than nan) for this piece of the abs(x) < 1.e-6 branch.