NaN values in torch.where's unselected branch still break gradients

Hi all,

I’m using torch.where to avoid backpropagating through branches that might yield NaN values. Specifically, I want to raise a number to a non-integer power if it’s nonnegative and do something else otherwise (raising a negative number to a non-integer power has a complex result, which PyTorch returns as NaN for real tensors).

However, when I backpropagate from the loss, the resulting gradient is still NaN, even though the loss itself has the expected value. Here’s a minimal working example. Swapping the autograd anomaly detection for a nullcontext lets you check that a.grad is NaN.

import torch

a = torch.as_tensor(-0.1).requires_grad_()
with torch.autograd.detect_anomaly():
    r = torch.where(a < 0, a * 10, a ** 1.8)
    r.backward()  # anomaly detection raises here because the backward pass produces NaN
In [77]: r
Out[77]: tensor(-1., grad_fn=<SWhereBackward>)
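
For reference, here is a sketch of the nullcontext variant mentioned above. It only swaps the context manager so that backward() completes instead of raising, and the variable names mirror the example.

```python
import contextlib

import torch

a = torch.as_tensor(-0.1).requires_grad_()
# nullcontext stands in for torch.autograd.detect_anomaly(), so backward() does not raise
with contextlib.nullcontext():
    r = torch.where(a < 0, a * 10, a ** 1.8)
    r.backward()

# The forward value comes from the selected branch (a * 10) and is finite,
# but the gradient picks up 0 * NaN from the unselected a ** 1.8 branch.
print(r.item(), a.grad)
print(torch.isnan(a.grad))
```

The zero that torch.where sends back to the unselected branch gets multiplied by the NaN local derivative of the power, and 0 * NaN is still NaN.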

Any ideas on how to circumvent this? Is this considered a bug?

Thanks in advance.


There are GitHub issues, like this one, that discuss this behavior:
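
In the meantime, one workaround that is commonly used for this kind of problem is to sanitize the input of the unsafe branch with a second torch.where, so the power never sees a negative base. This is only a minimal sketch under that assumption: neg and safe_a are illustrative names, the placeholder value 1.0 is my choice, and I can't say this is exactly what the issue recommends.

```python
import torch

a = torch.as_tensor(-0.1).requires_grad_()
neg = a < 0

# Replace negative inputs with a harmless placeholder (1.0) before the pow,
# so neither the forward nor the backward of pow ever sees a negative base.
safe_a = torch.where(neg, torch.ones_like(a), a)

# The outer where still selects a * 10 for negative inputs.
r = torch.where(neg, a * 10, safe_a ** 1.8)
r.backward()

print(r.item(), a.grad)
```

The placeholder only needs to be a value where both the power and its derivative are finite. The gradient that flows into it is masked out by the inner where, so it never reaches a.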

Hope this helps.