I'm using torch.where to avoid backpropagating through branches that might yield NaN values. Specifically, I want to exponentiate a number if it's nonnegative and do something else otherwise (as exponentiating a negative number to a fractional power may yield imaginary numbers).
However, when backpropagating from the loss, the resulting gradient is still NaN, even though the loss itself has the expected value. Here's a minimal working example; one can swap in a nullcontext for the autograd anomaly detection to check that a.grad is NaN.
```python
import torch

a = torch.as_tensor(-0.1).requires_grad_()
with torch.autograd.detect_anomaly():
    # take the linear branch for negative inputs, exponentiate otherwise
    r = torch.where(a < 0, a * 10, a ** 1.8)
    r.backward()
```
```python
In : r
Out: tensor(-1., grad_fn=<SWhereBackward>)
```
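For reference, this is the nullcontext variant I mean (a minimal sketch of the same snippet with the anomaly-detection context swapped out, so backward() runs to completion and a.grad can be inspected):

```python
from contextlib import nullcontext

import torch

a = torch.as_tensor(-0.1).requires_grad_()
with nullcontext():  # no anomaly detection, so backward() completes silently
    r = torch.where(a < 0, a * 10, a ** 1.8)
    r.backward()

print(r)       # tensor(-1., grad_fn=<SWhereBackward>) -- the loss is as expected
print(a.grad)  # tensor(nan) -- the gradient is still NaN
```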
Any ideas on how to circumvent this? Is this considered a bug?
Thanks in advance.