Training with mixed precision: loss is NaN despite finite output in forward pass

When training a BERT-like model on my custom dataset with PyTorch's built-in automatic mixed precision, I ran into an issue that I have been unable to resolve despite a lot of effort.
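
For reference, my training step follows the usual autocast + GradScaler recipe. Simplified, it looks roughly like this (placeholder model and data below; the real code lives in train_one_epoch and a small loss_scaler wrapper around GradScaler):

import torch
from torch import nn

# placeholders standing in for my real model, data and loss
model = nn.Linear(768, 10).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
max_norm = 1.0

scaler = torch.cuda.amp.GradScaler()

for step in range(100):
    samples = torch.randn(8, 768, device="cuda")
    targets = torch.randint(0, 10, (8,), device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(samples), targets)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                                    # clip on unscaled gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)
    scaler.update()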

My model uses the following components (taken from timm):

from torch import nn
from timm.models.layers import DropPath, Mlp

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

Training went well for a few hours before the loss became NaN. To find out what happened, I wrapped the training loop in the detect_anomaly context manager, like this:

with torch.autograd.detect_anomaly():
    train_stats = train_one_epoch(model,...)

Then I obtained the following message:

RuntimeError: Function 'SoftmaxBackward0' returned nan values in its 0th output

together with a nice traceback (immense thanks to @albanD for implementing torch.autograd.detect_anomaly and saving the world an infinite amount of debugging time), which pointed me to the line attn = attn.softmax(dim=-1) in the Attention forward method. I suspected an overflow when applying the softmax in float16, so I tried letting it run in full precision:

        # borrowed from https://github.com/huggingface/transformers/pull/18057/files
        if attn.dtype == torch.float16:
            attn = attn.softmax(dim=-1, dtype=torch.float32).to(torch.float16)
        else:
            attn = attn.softmax(dim=-1)

With this change, the backward pass no longer had an issue at the softmax, but this time I got:

RuntimeError: Function 'NativeDropoutBackward0' returned nan values in its 0th output.

The traceback points to the next line, attn = self.attn_drop(attn). Perhaps there was still an issue after the softmax, so I tried adding an extra layer of clamping:

        # borrowed from https://github.com/huggingface/transformers/pull/18057/files
        if attn.dtype == torch.float16:
            attn = attn.softmax(dim=-1, dtype=torch.float32).to(torch.float16)
            # clamp inf values to enable fp16 training
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
            if torch.isinf(attn).any():
                clamp_value = torch.finfo(attn.dtype).max - 1000
                attn = torch.clamp(attn, min=-clamp_value, max=clamp_value)
        else:
            attn = attn.softmax(dim=-1)

Then I obtained:

RuntimeError: Function 'BmmBackward0' returned nan values in its 1th output.

but this time with a rather strange traceback:

Epoch: [0]  [ 120/2669]  eta: 0:50:30  lr: 0.000017  loss: 6.8861 (6.9013)  time: 1.1663  data: 0.0002  max mem: 12126
Epoch: [0]  [ 140/2669]  eta: 0:50:00  lr: 0.000020  loss: 6.8800 (6.8984)  time: 1.1699  data: 0.0002  max mem: 12126
Traceback (most recent call last):
  File "/home/code/train_bert.py", line 367, in <module>
    main(args)
  File "/home/code/train_bert.py", line 301, in main
    train_stats = train_one_epoch(
  File "/home/code/engine.py", line 68, in train_one_epoch
    loss_scaler(loss, optimizer, clip_grad=max_norm,
  File "/home/code/util/misc.py", line 290, in __call__
    self._scaler.scale(loss).backward(create_graph=create_graph)
  File "/home/.conda/envs/cuda11/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/.conda/envs/cuda11/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'BmmBackward0' returned nan values in its 1th output.

The error happened when calling scale(loss).backward(), but I have no idea where BmmBackward0 comes from (@albanD, is this a bug? The traceback doesn't point to the line of code where the bmm was applied). In the Attention forward pass there is a matrix multiplication, attn @ v, so I thought it might come from there, and I tried applying the same clamping as above to the variable v, but this did not help.
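
Concretely, that attempt looked roughly like this (the same clamp as before, applied to v just before the matmul):

        # what I tried (did not help): the same clamp, applied to v before attn @ v
        if v.dtype == torch.float16 and torch.isinf(v).any():
            clamp_value = torch.finfo(v.dtype).max - 1000
            v = torch.clamp(v, min=-clamp_value, max=clamp_value)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)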

I hope somebody here could help me to resolve this issue, or at least to find out where the issue comes from. Thank you very much in advance!

For more information, I checked the intermediate values in the Attention forward pass and found no issue; all of them were finite.

For example, I tried replacing the line x = (attn @ v).transpose(1, 2).reshape(B, N, C) with:

        x = attn @ v
        if torch.any(~torch.isfinite(x)):
            print(x)
            raise Exception('attn @ v has undefined values')
        x = x.transpose(1, 2).reshape(B, N, C)

and there was no error.
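
(A more systematic version of the same check is to register a forward hook on every submodule and flag non-finite outputs; a minimal sketch, with a placeholder model standing in for mine:)

import torch
from torch import nn

def make_check_finite_hook(name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
            raise RuntimeError(f"non-finite output in module {name}")
    return hook

model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))  # placeholder for the real model
for name, module in model.named_modules():
    module.register_forward_hook(make_check_finite_hook(name))

model(torch.randn(2, 8))  # raises as soon as any submodule produces inf/NaN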

autocast will use float32 in softmax layers already so your manual casting shouldn’t help.
Note that some iterations are expected to create invalid gradients e.g. if the loss scaling factor is too large. In this case the scaler.step call will skip the optimizer.step() operation and will reduce the scaling factor in its scaler.update() call.
Using detect_anomaly would trigger errors even for these expected skipped iterations.
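
You can see this skipping behavior through the public GradScaler API, e.g. by comparing the scale before and after the update() call. A small illustrative check (the inf is injected by hand here; this is not showing the internals):

import torch
from torch import nn

model = nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss = model(torch.randn(4, 10, device="cuda")).sum()

scaler.scale(loss).backward()
next(model.parameters()).grad.fill_(float("inf"))  # simulate a bad iteration

prev_scale = scaler.get_scale()
scaler.step(optimizer)  # finds the inf grads and skips optimizer.step()
scaler.update()         # lowers the scale for the next iteration
print(prev_scale, "->", scaler.get_scale())  # 65536.0 -> 32768.0 with the default settings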

Thanks, @ptrblck, for your reply!

autocast will use float32 in softmax layers already so your manual casting shouldn’t help.

Good to know, thanks! But if softmax was already running in float32, why did manually casting it to float32 (and then back to float16) move the error to the next layer (as shown above)? I'm a bit confused…
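
If I understand correctly, the output dtype should reflect the precision the op ran in, so a quick standalone check like this should show it (my assumption; I have not looked at the autocast internals):

import torch

q = torch.randn(1, 8, 16, 64, device="cuda")
k = torch.randn(1, 8, 16, 64, device="cuda")
scale = 64 ** -0.5

with torch.cuda.amp.autocast():
    attn = (q @ k.transpose(-2, -1)) * scale  # matmul runs in float16 under autocast
    print(attn.dtype)                         # torch.float16
    out = attn.softmax(dim=-1)                # softmax is on autocast's float32 op list
    print(out.dtype)                          # expected: torch.float32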

Note that some iterations are expected to create invalid gradients e.g. if the loss scaling factor is too large. In this case the scaler.step call will skip the optimizer.step() operation and will reduce the scaling factor in its scaler.update() call.
Using detect_anomaly would trigger errors even for these expected skipped iterations.

Thanks for the information. I think I will need to read the documentation on torch.cuda.amp carefully to fully understand this.


I thought there was a bug in the traceback of torch.autograd.detect_anomaly, but I turned out to be wrong, as pointed out by @albanD in this GitHub issue. Previously I had disabled warnings when launching my training, so the full traceback wasn't shown. With warnings enabled, I can now see that the error, unsurprisingly, moved to the next layer:

x = (attn @ v).transpose(1, 2).reshape(B, N, C)

But this time I did not manage to find a solution: clamping v did not help. Would you have an idea of how to resolve this, please? Thank you so much for your help!