LayerNorm's grads become NaN after first epoch


I’ve got a network containing:

Input → LayerNorm → LSTM → Relu → LayerNorm → Linear → output

Gradient clipping is set to a value around 1. After the first training epoch, the gradients of the input LayerNorm are all NaN, even though the input to the first pass contains no NaN or Inf values, so I have no idea why this is happening or how to prevent it. If I remove some components (e.g. the second LayerNorm), the NaN gradient just moves to another component (such as Linear.bias).

The only thing that comes to mind is that I’ve got AMP mode enabled. After the first iteration, the loss function maps the output to a value of roughly 100. During the backward pass, AMP scales this loss value up to ~600,000. Is it possible that this value is propagated backward and eventually causes the LayerNorm’s gradients to become NaN?

I’ve seen similar errors occur even with AMP disabled, so I’m wondering: big picture, how am I supposed to guard against this? I mean, it doesn’t sound like I’m doing anything wrong per se. The inputs are valid. I’ve got gradient clipping enabled. And yet this error still occurs.

Any ideas?

Could you try running your code with torch.autograd.detect_anomaly? (See "Automatic differentiation package - torch.autograd — PyTorch 1.9.1 documentation".)
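For reference, here is a minimal sketch of what anomaly mode reports. The toy computation (`torch.sqrt` of zero, multiplied by zero) is an assumption chosen only because its forward pass is finite while its backward pass computes 0/0; the exact wording of the error can differ between PyTorch versions.

```python
import torch

def first_nan_backward_error():
    """Return the anomaly-mode error message, or None if backward succeeds."""
    x = torch.zeros(1, requires_grad=True)
    with torch.autograd.detect_anomaly():
        # Forward pass is finite: sqrt(0) * 0 == 0.
        y = torch.sqrt(x) * 0.0
        try:
            # SqrtBackward computes grad / (2 * sqrt(x)) = 0 / 0 = NaN at x == 0.
            y.sum().backward()
        except RuntimeError as err:
            # Anomaly mode names the backward node that produced the NaN.
            return str(err)
    return None

print(first_nan_backward_error())
```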

You could also check for an overflow when scaling your gradient. If you’re running with a low bit-width, scaling too high might cause an overflow, which would create NaNs when backpropagating! Also, check for Infs, as they can lead to NaNs under certain operations (e.g. inf - inf or 0 * inf)!
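A sketch of that check, to run right after `backward()` (and after unscaling, if a GradScaler is involved). The helper name `nonfinite_grad_names` is hypothetical, not a PyTorch API; it only relies on `named_parameters()` and `torch.isfinite`.

```python
import torch
import torch.nn as nn

def nonfinite_grad_names(model: nn.Module) -> list:
    """Return the names of parameters whose .grad contains NaN or +/-Inf (hypothetical helper)."""
    names = []
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            names.append(name)
    return names

model = nn.Linear(2, 1)
model(torch.ones(1, 2)).sum().backward()
print(nonfinite_grad_names(model))  # []

model.bias.grad[0] = float("inf")  # simulate an overflowed gradient
print(nonfinite_grad_names(model))  # ['bias']
```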

detect_anomaly is already enabled which is how I discovered this problem to begin with.

You could also check for an overflow when scaling your gradient?

How do I do that? What class/method do I set a breakpoint in?

If you’re running with a low bit-width, scaling too high might cause an overflow which would create NaNs when backproping!

Okay, but I’m not the one deciding how high values are being scaled, so if this kind of overflow happens in valid networks, how do you deal with it? I was under the impression that gradient clipping was supposed to fix this, but it does not.

Also, check for Infs as they can lead to NaNs under certain operations!

I already checked the input and output and neither contain NaN or Inf.

What’s the output from it? It should give a bit more detail about why it crashes. I assume it’s the error from your question (LayerNorm’s grads become NaN)? If so, which LayerNorm is it, given that you have more than one in your code?

If there are no NaNs / Infs in your inputs, do they only show up in the outputs of particular layers? You could try using register_full_backward_hook to see whether any NaNs / Infs get passed back through the network. Given that the NaN is in the gradient, it’s most likely happening when computing your loss. You can create a function like

def _save_output(module, grad_input, grad_output):
  # Backward hook signature: (module, grad_input, grad_output)
  print(module, grad_output)

and then apply that function to every LayerNorm module within your network, like

for module in model.modules():
  if isinstance(module, nn.LayerNorm):
    module.register_full_backward_hook(_save_output)
# or you can register the hook on each LayerNorm module manually (if the loop doesn't work)

Then, when you backprop your loss, it’ll print the gradient of the loss with respect to each layer’s output. That might help you find which layer (more specifically, which LayerNorm in your case) is causing the NaN issue. Granted, the gradient of your loss with respect to a layer’s parameters differs from the grad_output variable, but grad_output is still used in computing it, so if it contains a NaN, it’ll show you which layer is failing.
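A possible variant, sketched under the assumption that you only care about where non-finite values first appear: instead of printing every `grad_output`, the hook records only the modules whose `grad_output` contains NaN or Inf. The factory `make_nonfinite_hook` is a hypothetical helper, and the toy `nn.Sequential` stands in for the real network.

```python
import torch
import torch.nn as nn

def make_nonfinite_hook(reports):
    """Build a full backward hook that records modules receiving a non-finite grad_output."""
    def hook(module, grad_input, grad_output):
        # grad_output holds d(loss)/d(module output); entries may be None.
        for i, g in enumerate(grad_output):
            if g is not None and not torch.isfinite(g).all():
                reports.append(f"{type(module).__name__}: grad_output[{i}] is not finite")
    return hook

reports = []
model = nn.Sequential(nn.LayerNorm(4), nn.Linear(4, 2), nn.LayerNorm(2))
handles = [m.register_full_backward_hook(make_nonfinite_hook(reports))
           for m in model.modules() if isinstance(m, nn.LayerNorm)]

x = torch.randn(3, 4, requires_grad=True)
model(x).sum().backward()
print(reports)  # finite gradients, so nothing is recorded

out = model(x)
out.backward(torch.full_like(out, float("inf")))  # inject an Inf upstream gradient
print(reports[0])  # the last LayerNorm is reached first during backward

for h in handles:
    h.remove()  # detach the debugging hooks when done
```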

Just to check: you’re only using AMP within the forward pass? (Like in the tutorial: Automatic Mixed Precision package - torch.cuda.amp — PyTorch 1.9.1 documentation.) Backpropagating with AMP enabled might give rise to your NaNs. (Perhaps try running the backward pass outside AMP and see whether the NaN in LayerNorm.grad goes away?)
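For comparison, a minimal sketch of the pattern from the torch.cuda.amp documentation: autocast wraps only the forward pass and the loss, `backward()` runs outside it on the scaled loss, and gradients are unscaled before clipping. The tiny model, the data, and `max_norm=1.0` are placeholders; AMP is switched off when CUDA is unavailable so the sketch also runs on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

use_amp = torch.cuda.is_available()  # the torch.cuda.amp tools below need a CUDA device
device = "cuda" if use_amp else "cpu"

model = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 2)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def train_step(x, y, max_norm=1.0):
    optimizer.zero_grad(set_to_none=True)
    # Only the forward pass and the loss computation run under autocast.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = F.mse_loss(model(x), y)
    # backward() runs outside autocast, on the scaled loss.
    scaler.scale(loss).backward()
    # Unscale in place so that clipping sees the true gradient norms.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # step() skips optimizer.step() if Inf/NaN gradients were found; update() then lowers the scale.
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

x = torch.ones(4, 1, device=device)
y = torch.ones(4, 2, device=device)
print(train_step(x, y))
```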

Here is a self-contained testcase for reproducing the problem:

import math
import os
from typing import Optional, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor
from torch.utils.data import DataLoader, Subset, Dataset

# Bug report:
# Debugging "one of the variables needed for gradient computation has been modified by an inplace operation"

class OutdoorTemperatureDataset(Dataset):
    def __init__(self, batch_size: int):
        self.batch_size = batch_size
        self.input_horizon = 1
        self.output_horizon = 2
        self.total_horizon = self.input_horizon + self.output_horizon
        self.outdoor_temperature = torch.tensor([1.0]).repeat(2, self.total_horizon)

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        samples = torch.stack([self.outdoor_temperature[index]])
        # Convert [features, samples] to [samples, features]
        samples = samples.permute(1, 0)
        x = samples[:self.input_horizon, :]
        y = samples[self.input_horizon:, 0]
        return x, y

    def __len__(self):
        return self.outdoor_temperature.shape[0]

class ProcessContext:
    def __init__(self, dataset: OutdoorTemperatureDataset):
        self.input_horizon = dataset.input_horizon
        self.output_horizon = dataset.output_horizon
        train_size = max(1,
                         min(len(dataset) - 1,
                             math.ceil(len(dataset) * 0.9)))
        val_size = len(dataset) - train_size
        assert train_size > 0
        assert val_size > 0
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            Subset(dataset, range(0, train_size + val_size)),
            [train_size, val_size])

    def get_train_dataset(self):
        return self.train_dataset

    def get_validation_dataset(self):
        return self.val_dataset

    def get_model(self, learning_rate: float, max_epochs: int, hidden_layer_size: int, batch_size: int):
        return Predictor(self.train_dataset, self.val_dataset, self.input_horizon, self.output_horizon,
                         learning_rate=learning_rate, max_epochs=max_epochs,
                         hidden_layer_size=hidden_layer_size, batch_size=batch_size)

class Predictor(LightningModule):
    def __init__(self, train_dataset: Dataset, val_dataset: Dataset, input_horizon: int, output_horizon: int,
                 learning_rate: float, max_epochs: int, hidden_layer_size: int, batch_size: int):
        super(Predictor, self).__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.input_horizon = input_horizon
        self.output_horizon = output_horizon
        self.total_horizon = self.input_horizon + self.output_horizon
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.hidden_layer_size = hidden_layer_size

        self.input_norm = nn.LayerNorm(1)
        self.layer_norm = nn.LayerNorm(self.hidden_layer_size)
        self.lstm = nn.LSTM(1, self.hidden_layer_size, 1)

        self.linear_layer = nn.Linear(self.hidden_layer_size, self.output_horizon)
        self.loss_function = F.mse_loss
        self.batch_size = batch_size
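        # Print the gradients flowing through every module during backward (debugging aid)
        for module in self.modules():
            module.register_full_backward_hook(self._output_grads)

    def _output_grads(self, module, grad_input, grad_output):
        print(f"module={module}\ngrad_input={grad_input}\ngrad_output={grad_output}\n")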

    def train_dataloader(self):
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size, pin_memory=True)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def forward(self, input):
        output = self.input_norm(input)
        # Input shape is [batch, sequence, feature] but lstm/gru expects [sequence, batch, feature]
        output = output.permute(1, 0, 2)
        output, _ = self.lstm(output)
        # Extract the hidden layer of the last element of the sequence
        output = output[-1, :, :]
        output = F.relu(output)
        output = self.layer_norm(output)
        output = self.linear_layer(output)
        return output

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        input, expected = batch

        actual = self(input)
        return self.loss_function(actual, expected)

    def validation_step(self, batch, batch_index) -> Optional[STEP_OUTPUT]:
        input, expected = batch

        actual = self(input)
        loss = self.loss_function(actual, expected)
        self.log("val_loss", loss)  # read back in train() via trainer.logged_metrics
        return loss

def train(dataset: OutdoorTemperatureDataset, learning_rate: float, max_epochs: int,
          hidden_layer_size: int) -> float:
    process_context = ProcessContext(dataset)
    model = process_context.get_model(learning_rate, max_epochs, hidden_layer_size, dataset.batch_size)
    model.learning_rate = learning_rate

    trainer = Trainer(gpus=-1, benchmark=not DETERMINISTIC, precision=16, weights_summary=None,
                      max_epochs=max_epochs, deterministic=DETERMINISTIC, num_sanity_val_steps=0,
                      gradient_clip_val=1.0)  # clipping value of ~1, as described in the question
    trainer.fit(model)
    return trainer.logged_metrics["val_loss"]

DETERMINISTIC = True      # assumed value; not shown in the original testcase
DETERMINISTIC_SEED = 0    # assumed value; not shown in the original testcase
HIDDEN_LAYER_SIZE = 2     # matches LSTM(1, 2) / LayerNorm((2,)) in the hook output below

def main():
    if DETERMINISTIC:
        pl.seed_everything(DETERMINISTIC_SEED, workers=True)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    LEARNING_RATE = 0.7320207424172239
    MAX_EPOCHS = 1000
    batch_size = 1
    while True:
        dataset = OutdoorTemperatureDataset(batch_size)
        try:
            train(dataset, LEARNING_RATE, MAX_EPOCHS, HIDDEN_LAYER_SIZE)
            break
        except RuntimeError as e:
            message = repr(e)
            if "CUDNN_STATUS_EXECUTION_FAILED" in message or "CUDA out of memory" in message:
                batch_size = batch_size // 2
                if batch_size <= 0:
                    raise e
                print(f"Reducing batch_size to {batch_size}")
            else:
                raise e

if __name__ == "__main__":
    main()

And here is the output I get from the backward pass hook:

  (input_norm): LayerNorm((1,), eps=1e-05, elementwise_affine=True)
  (layer_norm): LayerNorm((2,), eps=1e-05, elementwise_affine=True)
  (lstm): LSTM(1, 2)
  (linear_layer): Linear(in_features=2, out_features=2, bias=True)
grad_output=(tensor([[  -inf, 52672.]], device='cuda:0', dtype=torch.float16),)

module=Linear(in_features=2, out_features=2, bias=True)
grad_input=(tensor([[-inf, inf]], device='cuda:0'),)
grad_output=(tensor([[  -inf, 52672.]], device='cuda:0', dtype=torch.float16),)

I’m not sure how to interpret this output. It looks like the first problem is that linear_layer has an input gradient of [[-inf, inf]], but where is that gradient coming from? Does it come from the loss function, since we’re moving backwards through the layers? Why would F.mse_loss produce a gradient of inf?

What should I try next?

PS: Gradient clipping doesn’t actually work in PyTorch Lightning due to the GitHub issue "Gradient clip norm is called before AMP's unscale leading to wrong gradients" (Issue #9330, PyTorchLightning/pytorch-lightning), but I don’t believe that’s relevant here, because we fail on the first backward pass, before gradient clipping is even supposed to be invoked.

So, grad_input is the gradient of the loss w.r.t. the layer’s input.

I see you’re using torch.float16, whose maximum representable value is about 6.55 × 10^4 (65504), and you stated that AMP scales the loss to ~600,000. You’re overflowing your loss, and that’s where the Inf is coming from!
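To make the range concrete, a small pure-Python check. It uses `struct`'s `'e'` format, which packs IEEE 754 half precision, the same format as `torch.float16`. The helper `fits_in_float16` is hypothetical; the 600,000 figure is the scaled loss quoted in the question.

```python
import math
import struct

def fits_in_float16(value: float) -> bool:
    """Return True if `value` is representable as a finite IEEE 754 half-precision number."""
    if not math.isfinite(value):
        return False
    try:
        struct.pack("<e", value)  # 'e' = IEEE 754 binary16, the format behind torch.float16
    except OverflowError:
        return False  # rounds past 65504, the largest finite float16
    return True

print(fits_in_float16(100.0))      # True: the unscaled loss
print(fits_in_float16(65504.0))    # True: torch.finfo(torch.float16).max
print(fits_in_float16(600_000.0))  # False: the scaled loss from the question
```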

  1. Generally speaking, is it normal to end up with inf, nan problems in a network even if the inputs do not contain them? Or does it always indicate there is a bug somewhere in our code?

  2. If it is normal, how are we expected to handle it? Are we expected to add clipping or clamping of values somewhere?

  3. The hook you provided prints out the gradients but not the scaled and unscaled values of each layer (the loss, in this case). Is there another hook I can add to do that?

  4. The unscaled loss is around 100 which seems quite reasonable/low. How do I find out why it is being scaled up to such a large value in the first place?

  5. I’m not sure if PyTorch Lightning uses AMP in the backwards pass. How do I find out?


  1. Your NaNs emerge when calculating the gradient of your loss w.r.t. your parameters, so you won’t see them in your input; you’ll only see them when computing gradients. If your loss is Inf, the gradients of that loss w.r.t. the parameters can become Inf or NaN.

  2. Clamping the output to stop it from overflowing could help, but a simpler solution would be to ask whether you really need to run your code in torch.float16.

  3. The hook prints the gradient that is used during optimization, so I assume it’s the scaled gradient (as that’s what AMP uses during backprop).
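A hand-rolled illustration of point 3, assuming (as the answer does) that backward hooks see gradients before unscaling: multiplying the loss by the scale factor multiplies every gradient by the same factor, and dividing by it recovers the true gradient. This mimics `scaler.scale(loss).backward()` and `scaler.unscale_()` by hand; it is not the GradScaler implementation.

```python
import torch

scale = 2.0 ** 16  # GradScaler's default init_scale (65536)
w = torch.tensor([1.5], requires_grad=True)
x = torch.tensor([2.0])

loss = ((w * x - 1.0) ** 2).sum()   # (1.5 * 2 - 1)^2 = 4
(loss * scale).backward()            # roughly what scaler.scale(loss).backward() does
scaled_grad = w.grad.clone()         # what autograd produced: 8 * 65536
unscaled_grad = scaled_grad / scale  # what unscaling recovers: d(loss)/dw = 2 * (w*x - 1) * x = 8

print(scaled_grad.item(), unscaled_grad.item())  # 524288.0 8.0
print(torch.isinf(scaled_grad.half()).item())    # True: 524288 does not fit in float16
```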

So if I understand you correctly, even perfectly valid inputs and models can result in Inf, NaN problems. Meaning, this does not necessarily indicate a bug in my code.

Right, except that I’ve run into a similar error in a bigger project that already uses torch.float32, so I’d like to figure out how to clamp the output to protect against such failures. Given:

def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
    input, expected = batch

    actual = self(input)
    loss = self.loss_function(actual, expected)
    return loss

Is this the right way to go about clamping the output?

def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
    input, expected = batch

    actual = self(input)
    loss = self.loss_function(actual, expected)
    limits = torch.finfo(torch.float16)
    return torch.clamp(loss, limits.min, limits.max) # <--- clamp added

I tried this in the testcase but the problem remains.
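One likely reason the clamp has no effect, sketched with the numbers reported in this thread: the clamp runs inside `training_step`, before the loss is multiplied by the scale factor, and a loss of about 1.4 is already far inside the float16 range. The clamp also never touches the float16 intermediate gradients computed during backward. The 2**16 factor assumes the scaler's default initial scale.

```python
import torch

fp16 = torch.finfo(torch.float16)   # fp16.max == 65504.0
loss = torch.tensor(1.4138)          # loss value reported at the end of the forward pass
clamped = torch.clamp(loss, fp16.min, fp16.max)  # no-op: the loss is already in range
scaled = clamped * 2.0 ** 16         # the scale factor is applied after training_step returns

print(torch.equal(clamped, loss))   # True
print(scaled.item() > fp16.max)     # True (about 92,655 > 65,504)
```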

You’re probably right, but I had to step deep into PyTorch’s code to figure out that the loss was being scaled up from 100 to ~600,000. Is there a hook I can register that will print out the layer values so I don’t have to do this in the future? The goal is to see all of the weight and gradient values from the debugging hooks.

  1. Walking through the clamped testcase code, I see that the loss is equal to 1.4138 at the end of the forward pass and somehow gets scaled up to 92655.7891 right before backward() gets invoked. Something smells wrong here. Why is such a small loss value (which is being clamped, no less!) being scaled up to a value that is out of range? I tracked this down to code in PyTorch Lightning where they scale the loss up by a factor of 65536. From the looks of things, the value will always overflow… Any idea what could be going on?


First part yes, second part no. Infs and NaNs produced by your code indicate a bug somewhere in your code; it’s just not a result of the inputs.

Hmm, that sounds odd, given that the max value for float32 is around 3.4 × 10^38. Check that no Infs are being created as inputs to your layers. You can use a similar method with register_forward_pre_hook (see docs). Perhaps there’s an issue with your LSTM module, whose output eventually reaches the final Linear layer (the one with the Infs)?
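A sketch of such a forward pre-hook, which raises as soon as any layer receives a NaN or Inf input. `make_input_checker` is a hypothetical helper and the `nn.Sequential` model is a stand-in; for the real network the same loop would cover the LSTM, LayerNorm, and Linear submodules.

```python
import torch
import torch.nn as nn

def make_input_checker(name):
    """Build a forward pre-hook that raises if any tensor input to a module is NaN/Inf."""
    def pre_hook(module, inputs):
        # `inputs` is the tuple of positional arguments passed to module.forward().
        for i, t in enumerate(inputs):
            if torch.is_tensor(t) and not torch.isfinite(t).all():
                raise RuntimeError(f"non-finite input[{i}] to {name}")
    return pre_hook

model = nn.Sequential(nn.Linear(2, 2), nn.LayerNorm(2))
for name, module in model.named_modules():
    if name:  # skip the top-level container, whose name is ""
        module.register_forward_pre_hook(make_input_checker(name))

model(torch.ones(1, 2))  # finite input: passes silently
try:
    model(torch.tensor([[float("inf"), 0.0]]))
except RuntimeError as err:
    print(err)  # non-finite input[0] to 0
```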

Reading through the PyTorch docs on it, you should use the unscaled gradients; I don’t know how that’s done within PyTorch Lightning.
Each parameter’s gradient (.grad attribute) should be unscaled before the optimizer updates the parameters, so the scale factor does not interfere with the learning rate.

Looks good to me

The hooks give you tensors for every input sample. If you want the gradients and weights, you can print those out as follows (the gradients are accumulated over all samples in the batch, but if there’s a NaN anywhere, the accumulated value will be NaN too, so it’s a simple check):

for name, param in model.named_parameters():
  print(name, param, param.grad)

I don’t know how this works in PyTorch Lightning, but see whether it’s at all possible to change the scale factor that multiplies the loss. Also see if you can reproduce the issue in plain PyTorch (rather than Lightning) by using AMP in the forward pass only!

FYI, I ended up asking the Lightning developers about the scaler causing overflow, and they wrote:

At the beginning of training, in the first few iterations it can happen that the grads become inf or nan. AMP will skip the optimizer.step() in that case and adjust the scale factor iteratively until it converges to a point where that doesn’t happen.

See: AMP scaler always causes backwards pass to overflow · Issue #9799 · PyTorchLightning/pytorch-lightning · GitHub
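The behaviour described above can be simulated in plain Python. This sketch assumes GradScaler's default parameters (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`) and simplifies the overflow test: the real scaler checks the gradients for Inf/NaN, whereas here a single hypothetical gradient magnitude times the scale is compared with the float16 maximum.

```python
FP16_MAX = 65504.0  # largest finite float16 value

def simulate_dynamic_scaling(grad_magnitude, steps, init_scale=2.0 ** 16,
                             growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Return (final_scale, skipped_steps) for a constant, hypothetical gradient magnitude."""
    scale = init_scale
    clean_streak = 0
    skipped = 0
    for _ in range(steps):
        if grad_magnitude * scale > FP16_MAX:
            # Overflow: the optimizer step is skipped and the scale is reduced.
            skipped += 1
            scale *= backoff_factor
            clean_streak = 0
        else:
            # Finite gradients: the step runs; after enough clean steps the scale grows again.
            clean_streak += 1
            if clean_streak == growth_interval:
                scale *= growth_factor
                clean_streak = 0
    return scale, skipped

print(simulate_dynamic_scaling(1.4138, 10))  # (32768.0, 1)
print(simulate_dynamic_scaling(100.0, 10))   # (512.0, 7)
```

With the loss of ~1.4 from the clamped testcase, only the first step is skipped; with a value around 100, the scale has to be halved seven times before steps go through.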

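I hope this helps other people who run across this in the future. So apparently you can’t enable torch.autograd.set_detect_anomaly(True) during the first epoch, because anomaly detection raises an error on the NaN gradients from those early overflowing steps, which AMP would otherwise just skip…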


This solves my issue. Thanks a lot!