How to avoid nan output from atan2 during backward pass?

Hi everyone, first post here!

I’m training a model that uses a single atan2 in its final stage. I’ve enabled anomaly detection with torch.autograd.set_detect_anomaly(True), and after a while I get an error that “[f]unction ‘Atan2Backward0’ returned nan values in its 0th output.” I know that atan2 can produce nan if both the numerator and denominator are 0, so I add a small epsilon to any denominator values that are exactly zero:

    # Calculate the phase of a complex spectrogram
    def find_phase(self, subbands: torch.Tensor) -> torch.Tensor:
        numerator = subbands[:, 1:2, :, :]
        denominator = subbands[:, 0:1, :, :]
        # Add epsilon to denominator to avoid nan
        epsilon = 1e-7
        nudge = (denominator == 0) * epsilon
        denominator = denominator + nudge
        return torch.atan2(numerator, denominator)

Strangely, though, the error still occurs. I wonder if I’m missing something here, and if there’s a better way to ensure that atan2 doesn’t return nan. Thanks!

Also, I’m training on an Nvidia A100 with PyTorch 2.0.0 + CUDA 11.7, if that’s any help.

Your approach sounds valid and also works in this minimal code snippet:

import torch

numerator = torch.randn(10, 10)
numerator[0, 0] = 0.
numerator.requires_grad_(True)
denominator = torch.randn(10, 10)
denominator[0, 0] = 0.

epsilon = 1e-7
nudge = (denominator == 0) * epsilon
denominator = denominator + nudge
out = torch.atan2(numerator, denominator)
out.mean().backward()
print(numerator.grad)

Could you check which values are set to NaN in the backward pass and check the inputs to this method?
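For example, something like this (just a sketch, reusing the toy tensors from above rather than your real find_phase inputs) would report which gradient entries become non-finite during the backward pass:

import torch

# Sketch: hooks that report non-finite gradient entries, so you can see
# exactly where the NaNs first appear in the backward pass.
def report_nonfinite(name):
    def hook(grad):
        if not grad.isfinite().all():
            print(f"non-finite grad for {name} at {(~grad.isfinite()).nonzero()}")
    return hook

numerator = torch.randn(10, 10)
numerator[0, 0] = 0.
numerator.requires_grad_(True)
denominator = torch.randn(10, 10)
denominator[0, 0] = 0.
denominator.requires_grad_(True)

numerator.register_hook(report_nonfinite("numerator"))
denominator.register_hook(report_nonfinite("denominator"))

out = torch.atan2(numerator, denominator)
out.mean().backward()
# -> non-finite grad for numerator at tensor([[0, 0]]) (and similarly for denominator)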

Thanks, I really appreciate the help :slight_smile:

So I iterated over all of the model parameters after calling loss.backward() and checked for NaN in each of the gradients and weights.

        # Check gradients and weights for non-finite values
        all_gradients_are_finite = True
        for parameter_name, parameter in self.model.named_parameters():
            if parameter.grad is not None:
                if not parameter.grad.isfinite().all():
                    print("Parameter gradient is non-finite: {}".format(parameter_name))
                    all_gradients_are_finite = False
            if not parameter.data.isfinite().all():
                print("Parameter data is non-finite: {}".format(parameter_name))

None of the weights contain NaN (probably because my training loop avoids calling optimizer.step() if any values are NaN) but all of the gradients do. The first layer with a NaN gradient in the backward pass is the last trainable layer of the model (Conv2d). This makes sense, I suppose, because the atan2 call in question appears after that final layer: it’s part of an output stage where I convert the estimated spectrogram back to a waveform. So if atan2 returns NaN in the backward pass it would propagate to the whole model.

I checked the inputs to the find_phase method and they don’t contain NaN at all during the forward pass. The loss doesn’t contain NaN either (as long as I don’t call optimizer.step() when NaN gradients are detected).

I also tried removing the find_phase method and the NaNs disappeared. So atan2 does seem to be the culprit. I’m a little ignorant about how autograd works, to be honest, but I wonder if adding epsilon here only guarantees a finite output from atan2 during the forward pass, and it’s still possible for the backward pass to be fed a zero-valued numerator and denominator?

Your observation indeed points towards atan2 and you can find the gradient definition here.

This shouldn’t happen as my previous code snippet would only fail without the eps, which also matches the gradient formula.
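For reference, for out = torch.atan2(y, x) the gradients are x / (x**2 + y**2) w.r.t. y and -y / (x**2 + y**2) w.r.t. x, so both become 0/0 = NaN only when x and y are exactly zero. A quick check (my own illustration, not taken from your model):

import torch

# For out = atan2(y, x):  d(out)/dy =  x / (x**2 + y**2)
#                         d(out)/dx = -y / (x**2 + y**2)
# Both are 0/0 = NaN only when x and y are exactly zero.
y = torch.tensor([0.0, 0.5], requires_grad=True)
x = torch.tensor([0.0, 2.0], requires_grad=True)
torch.atan2(y, x).sum().backward()
print(y.grad)  # [nan, 0.4706]   (0.4706 ≈ 2.0 / (2.0**2 + 0.5**2))
print(x.grad)  # [nan, -0.1176]  (-0.1176 ≈ -0.5 / (2.0**2 + 0.5**2))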

Could you add a check for zero outputs to the forward method, e.g. via:

if (out == 0.0).any():
    print(f"zero output detected at {(out==0.0).nonzero()}")

I wasn’t sure exactly where to add this check, but I tried directly after atan2 in the find_phase method as well as at the output of the model itself. Both checks did return zero outputs now and then, although this is to be expected since the model ultimately returns a waveform, and find_phase returns the phase of a complex spectrogram, both of which may contain zero crossings, silent sections, etc. Could zero outputs cause atan2 to return NaN in the backward pass, I wonder?

Based on this code, yes I think so:

numerator = torch.randn(10, 10)
numerator[0, 0] = 0.
numerator.requires_grad_(True)
denominator = torch.randn(10, 10)
denominator[0, 0] = 0.

# epsilon = 1e-10000
# nudge = (denominator == 0) * epsilon
# denominator = denominator + nudge
out = torch.atan2(numerator, denominator)
out.mean().backward()
print(numerator.grad)
# tensor([[        nan,  4.0503e-04, -3.5165e-03,  5.8801e-03,  1.3173e-02,
#          -2.1934e-02, -3.9802e-03, -1.4191e-03, -2.0235e-03,  7.3573e-03],

if (out == 0.0).any():
    print(f"zero output detected at {(out==0.0).nonzero()}")
# zero output detected at tensor([[0, 0]])

but note that both inputs are also set to zero in that example.
Given you are adding a small eps value to the denominator, this should not happen (as seen in my previous code), but something is still causing the issue, which is why I suggested digging a bit into the actual values now.
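For example, a quick (hypothetical) check right before the atan2 call, using the same variable name as in your find_phase method, could report how close the denominator actually gets to zero:

import torch

# Hypothetical check: report how close the (already nudged) denominator gets to zero.
denominator = torch.tensor([[0.3, 1e-20, -0.7]])  # stand-in for the real denominator
min_abs = denominator.abs().min()
if min_abs < 1e-6:
    print(f"denominator very close to zero: min |x| = {min_abs.item():.3e}")
# -> denominator very close to zero: min |x| = 1.000e-20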

I figured it out! By digging into the actual values, as you suggested. My find_phase method above replaces exactly-zero values with eps, yes, but values that are merely very close to zero (e.g. 1e-20) slip through unchanged, and the x**2 + y**2 term in the gradient is then so tiny that the backward pass still blows up to inf/NaN. The solution:

def find_phase(self, subbands: torch.Tensor) -> torch.Tensor:
    numerator = subbands[:, 1:2, :, :]
    denominator = subbands[:, 0:1, :, :]
    # Add epsilon to denominator to avoid NaN in backward pass
    epsilon = 1e-10
    near_zeros = denominator < epsilon
    denominator = denominator * (near_zeros.logical_not())
    denominator = denominator + (near_zeros * epsilon)
    return torch.atan2(numerator, denominator)

This replaces any values less than eps with eps. And no more NaN or inf! Autograd no longer detects any anomalies.

Here’s a minimal example that demonstrates the problem:

import torch

numerator = torch.tensor([1e-20], requires_grad=True)
denominator = torch.tensor([1e-20], requires_grad=True)
out = torch.atan2(numerator, denominator)
out.mean().backward()
# the gradient's x**2 + y**2 term is vanishingly small here, so the
# backward pass blows up even though the forward output is finite
print(numerator.grad)   # tensor([inf])
print(denominator.grad) # tensor([-inf])

Anyway, all fixed now. Thanks again for pointing me in the right direction! Been scratching my head for hours over this one and very relieved to see it resolved.

Ah, of course! I had even experimented with tiny values but was blind.
Great to hear you have figured it out. :slight_smile:

near_zeros = denominator < epsilon

The code in the previous post would be better modified as follows:

near_zeros = denominator.abs() < epsilon
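
For completeness, here is the earlier find_phase method with that change applied (just a sketch of the suggested modification, not re-tested on the original model):

def find_phase(self, subbands: torch.Tensor) -> torch.Tensor:
    numerator = subbands[:, 1:2, :, :]
    denominator = subbands[:, 0:1, :, :]
    # Replace values whose magnitude is below epsilon (positive or negative)
    # with epsilon, so the x**2 + y**2 term in the atan2 gradient stays away from zero
    epsilon = 1e-10
    near_zeros = denominator.abs() < epsilon
    denominator = denominator * near_zeros.logical_not()
    denominator = denominator + near_zeros * epsilon
    return torch.atan2(numerator, denominator)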