Hi Carlos!
First, this is a known issue (with no simple fix for torch.where()). See,
for example, GitHub issues 68425 and 70342.
Second, I believe that the best fix is to avoid producing nans (or the values
that subsequently produce nans) in the forward pass, even in the “branch
not taken” of the torch.where(). Instead, compute in their place “safe”
(non-nan-producing) values that can be incorrect, because torch.where()
will successfully discard the incorrect (but non-nan) gradients that they
lead to.
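Here is a minimal sketch of the idea, using sqrt() as a stand-in for whatever operation produces the nans in your case. The function names naive_sqrt_or_zero() and safe_sqrt_or_zero() and the "safe" value of 1.0 are my own choices for illustration, not something taken from your code. The naive version fails because, in the backward pass, the gradient flowing into the branch not taken is zero, but it still gets multiplied by sqrt()'s nan local derivative, and 0 * nan is nan.

```python
import torch


def naive_sqrt_or_zero(x):
    # sqrt() is evaluated for every element, including negative ones.
    # The forward result is fine, but the backward pass produces nan
    # for the elements whose sqrt() branch was not taken.
    return torch.where(x > 0, torch.sqrt(x), torch.zeros_like(x))


def safe_sqrt_or_zero(x):
    mask = x > 0
    # Replace the inputs of the "branch not taken" with a safe value
    # (1.0 is an arbitrary choice; any positive number works here).
    safe_x = torch.where(mask, x, torch.ones_like(x))
    # sqrt(safe_x) is wrong where mask is False, but it is finite, so the
    # outer torch.where() can discard it (and its gradient) cleanly.
    return torch.where(mask, torch.sqrt(safe_x), torch.zeros_like(x))


x = torch.tensor([-1.0, 4.0], requires_grad=True)
naive_sqrt_or_zero(x).sum().backward()
print(x.grad)  # first element is nan

y = torch.tensor([-1.0, 4.0], requires_grad=True)
safe_sqrt_or_zero(y).sum().backward()
print(y.grad)  # all elements are finite
```

Both versions return the same forward values. Only the gradients differ, which is why the problem tends to show up first in the backward pass.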
For a concrete example of how to feed “safe” values to the “branch not taken”
in a problem analogous to yours, see this post:
Best.
K. Frank