I'm using torch.where to avoid backpropagating through branches that might yield NaN values. Specifically, I want to exponentiate a number if it's nonnegative and do something else otherwise (as exponentiating a negative number to a fractional power may yield imaginary numbers).
However, when backpropagating from the loss, the resulting gradient is still NaN, even though the loss itself has the expected value. Here's a minimal working example; one can swap in a nullcontext for the autograd anomaly detection to check that a.grad is NaN.
```python
import torch

a = torch.as_tensor(-0.1).requires_grad_()
with torch.autograd.detect_anomaly():
    # take the linear branch for negative inputs, exponentiate otherwise
    r = torch.where(a < 0, a * 10, a ** 1.8)
    r.backward()
```
```python
In : r
Out: tensor(-1., grad_fn=<SWhereBackward>)
```
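For reference, this is the nullcontext variant I mean (a minimal sketch of the same snippet with the anomaly-detection context swapped out, so backward() runs to completion and a.grad can be inspected):

```python
from contextlib import nullcontext

import torch

a = torch.as_tensor(-0.1).requires_grad_()
with nullcontext():  # no anomaly detection, so backward() completes silently
    r = torch.where(a < 0, a * 10, a ** 1.8)
    r.backward()

print(r)       # tensor(-1., grad_fn=<SWhereBackward>) -- the loss is as expected
print(a.grad)  # tensor(nan) -- the gradient is still NaN
```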
Any ideas on how to circumvent this? Is this considered a bug?
Thanks in advance.