Incorrect gradient calculation with torch.where and NaNs

Hi Carlos!

First, this is a known issue (with no simple fix inside torch.where() itself).
See, for example, GitHub issues 68425 and 70342.
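
Here is a minimal illustration of the symptom. It assumes a recent PyTorch, and it uses torch.sqrt() of a negative number purely as a stand-in for whatever produces the nans in your code (which I haven't reproduced here):

```python
import torch

x = torch.tensor([-1.0, 4.0], requires_grad=True)

# torch.sqrt(-1.0) is nan in the forward pass, but torch.where() picks 0.0
# for that element, so the forward result y contains no nans.
y = torch.where(x > 0, torch.sqrt(x), torch.zeros_like(x))
y.sum().backward()

print(y.detach())  # no nan here
print(x.grad)      # x.grad[0] is nan anyway: the "branch not taken" leaks
```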

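Second, I believe that the best fix is to avoid producing nans (or the values,
such as infs, that subsequently produce nans) in the forward pass, even in the
“branch not taken” of the torch.where(). The nans leak because backward sends
a zero gradient to the branch not taken and then chains it through that
branch's local derivative; if that derivative is nan or inf, 0 * nan (or
0 * inf) is nan. Instead, compute “safe” (non-nan-producing) values in their
places. Those values may be incorrect, but that doesn't matter, because
torch.where() will discard the incorrect (but non-nan) gradients that they
lead to.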
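
As a sketch of this idea (sometimes called the “double where” trick), again using sqrt as an assumed stand-in and a hypothetical helper name safe_sqrt_where, the first torch.where() swaps in a safe input (1.0 here, chosen arbitrarily) for the elements that will be discarded, and the second torch.where() does the real selection:

```python
import torch

def safe_sqrt_where(x):
    # Hypothetical helper, for illustration only.
    positive = x > 0
    # Feed the "branch not taken" a safe input (1.0) whose sqrt and whose
    # gradient are finite. sqrt(1.0) is the wrong value for those elements,
    # but the outer torch.where() discards it (and its gradient).
    safe_x = torch.where(positive, x, torch.ones_like(x))
    return torch.where(positive, torch.sqrt(safe_x), torch.zeros_like(x))

x = torch.tensor([-1.0, 4.0], requires_grad=True)
y = safe_sqrt_where(x)
y.sum().backward()

print(y.detach())  # same forward values as the naive version
print(x.grad)      # no nan: the discarded element gets a gradient of 0
```

The forward result is unchanged, but now every intermediate value is finite, so the zero gradient for the discarded element stays zero instead of turning into nan.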

For a concrete example of how to feed “safe” values to the “branch not taken”
in a problem analogous to yours, see this post:

Best.

K. Frank
