When training a BERT-like model on my custom dataset using PyTorch’s built-in automatic mixed precision (AMP), I encountered an issue that I have been unable to resolve despite a lot of effort.
My model uses the following components (taken from timm):

```python
import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp  # newer timm versions also expose these under timm.layers


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # pre-norm residual branches: attention, then MLP
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```
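Since `detect_anomaly` only inspects the backward pass, a complementary check is to find which submodule first produces non-finite activations in the forward pass. Below is a minimal sketch using `register_forward_hook`; the helper `register_nonfinite_hooks` and the toy `nn.Sequential` model are hypothetical, chosen only so the snippet runs on its own. Note that module hooks only see module outputs, so an intermediate tensor such as `q @ k.transpose(-2, -1)` inside `Attention.forward` would only show up indirectly, through the `Attention` module's output.

```python
import torch
import torch.nn as nn


def register_nonfinite_hooks(model):
    """Hypothetical helper: record, in call order, the names of submodules
    whose forward output contains inf or NaN."""
    bad = []

    def make_hook(name):
        def hook(module, inputs, output):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for t in outputs:
                if torch.is_tensor(t) and not torch.isfinite(t).all():
                    bad.append(name)
                    break
        return hook

    # Skip the root module (empty name) so that only submodules are reported.
    handles = [m.register_forward_hook(make_hook(name))
               for name, m in model.named_modules() if name]
    return bad, handles


# Toy model whose first Linear is rigged to overflow float32.
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 2))
with torch.no_grad():
    model[0].weight.fill_(1.0)
    model[0].bias.zero_()

bad, handles = register_nonfinite_hooks(model)
with torch.no_grad():
    model(torch.full((2, 4), 1e38))  # 4 * 1e38 exceeds the float32 max (~3.4e38)
for h in handles:
    h.remove()

print(bad)  # ['0', '1', '2']: the Linear overflows first, then NaN propagates
```

The first name in the list is the interesting one; everything after it usually just inherits the non-finite values.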
Training went well for a few hours before the loss became NaN. To find out what happened, I wrapped the training loop in the `detect_anomaly` context manager, like this:

```python
with torch.autograd.detect_anomaly():
    train_stats = train_one_epoch(model, ...)
```
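For reference, here is a minimal, self-contained reproduction of the kind of error anomaly mode raises; taking the `sqrt` of a negative number is just a convenient way to manufacture a NaN and is unrelated to the model above. Per the `torch.autograd.detect_anomaly` documentation, running the forward pass with detection enabled also lets the backward pass print the traceback of the forward operation that created the failing node. That forward traceback typically appears as a separate warning before the exception, so it is easy to miss when looking only at the exception's own traceback.

```python
import torch

x = torch.tensor([-1.0, 4.0], requires_grad=True)
message = ""

with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)  # sqrt(-1) is NaN already in the forward pass
    try:
        # SqrtBackward0 computes grad / (2 * sqrt(x)), which is NaN for x = -1
        y.sum().backward()
    except RuntimeError as err:
        message = str(err)

print(message)
```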
I then got the following error:

```
RuntimeError: Function 'SoftmaxBackward0' returned nan values in its 0th output
```
together with a nice traceback (immense thanks to @albanD for implementing `torch.autograd.detect_anomaly` and saving the world an infinite amount of debugging time), which pointed me to the line `attn = attn.softmax(dim=-1)` in the `forward` method of `Attention`. I suspected an overflow in the softmax, so I tried computing it in full precision (float32):
```python
# borrowed from https://github.com/huggingface/transformers/pull/18057/files
if attn.dtype == torch.float16:
    attn = attn.softmax(dim=-1, dtype=torch.float32).to(torch.float16)
else:
    attn = attn.softmax(dim=-1)
```
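One caveat worth keeping in mind here: float16 tops out at 65504, so if a logit in `q @ k.transpose(-2, -1)` already overflowed to `inf` while still in half precision, upcasting inside the softmax keeps the `inf`, and a softmax row containing `inf` evaluates to NaN (the max-subtraction step computes `inf - inf`). It may also be worth knowing that, per the autocast op reference, CUDA autocast already runs `softmax` in float32. The sketch below uses made-up numbers on CPU to illustrate the overflow effect; it does not show that this is what happened in the model above.

```python
import torch

# float16 cannot represent anything above 65504.
fp16_max = torch.finfo(torch.float16).max
print(fp16_max)  # 65504.0

# Made-up logits; 90000 overflows to inf when stored in float16.
logits = torch.tensor([[1.0, 2.0, 90000.0]])
logits_fp16 = logits.to(torch.float16)

# Upcasting inside softmax does not undo an overflow that already happened:
# the row max is inf, inf - inf is NaN, and the result contains NaN.
probs = logits_fp16.softmax(dim=-1, dtype=torch.float32)
print(torch.isnan(probs).any().item())  # True

# The same logits kept in float32 give a well-defined result, about [0, 0, 1].
probs_fp32 = logits.softmax(dim=-1)
print(probs_fp32)
```

In other words, where the precision is raised matters: it has to happen before the value that can overflow is computed, not after.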
With this change, the backward pass no longer failed at the softmax, but this time I got:

```
RuntimeError: Function 'NativeDropoutBackward0' returned nan values in its 0th output.
```

The traceback points to the next line, `attn = self.attn_drop(attn)`. Perhaps there was still an issue after the softmax, so I tried adding an extra layer of clamping:
```python
# borrowed from https://github.com/huggingface/transformers/pull/18057/files
if attn.dtype == torch.float16:
    attn = attn.softmax(dim=-1, dtype=torch.float32).to(torch.float16)
    # clamp inf values to enable fp16 training
    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
    if torch.isinf(attn).any():
        clamp_value = torch.finfo(attn.dtype).max - 1000
        attn = torch.clamp(attn, min=-clamp_value, max=clamp_value)
else:
    attn = attn.softmax(dim=-1)
```
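Two observations that may help narrow this down, shown with toy tensors on CPU. First, softmax probabilities always lie in [0, 1], so clamping them to ±(`torch.finfo(torch.float16).max` - 1000) cannot change any finite value, and `torch.clamp` passes NaN through unchanged. Second, anomaly mode flags NaN but not `inf`: if an `inf` gradient reaches dropout's backward, the entries where the mask is zero become `inf * 0 = NaN`, so the node that gets reported can lie after (in backward order) the place where the overflow actually started. On CPU the dropout below may record a differently named backward node than `NativeDropoutBackward0`, so only the message wording is checked.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# 1) Clamping softmax output to the float16 range is a no-op for finite values.
probs = torch.randn(4, 8).softmax(dim=-1)
clamp_value = torch.finfo(torch.float16).max - 1000  # 64504.0
clamped = torch.clamp(probs, min=-clamp_value, max=clamp_value)
print(torch.equal(clamped, probs))  # True
print(torch.clamp(torch.tensor([float("nan")]), -clamp_value, clamp_value))  # tensor([nan])

# 2) An inf gradient reaching dropout's backward turns into NaN where the mask is 0.
x = torch.ones(100, requires_grad=True)
message = ""
with torch.autograd.detect_anomaly():
    y = F.dropout(x, p=0.5, training=True)
    upstream_grad = torch.full_like(y, float("inf"))  # pretend the incoming gradient overflowed
    try:
        y.backward(upstream_grad)
    except RuntimeError as err:
        message = str(err)
print(message)
```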
Then I got:

```
RuntimeError: Function 'BmmBackward0' returned nan values in its 1th output.
```

but this time with a rather strange traceback:
```
Epoch: [0] [ 120/2669] eta: 0:50:30 lr: 0.000017 loss: 6.8861 (6.9013) time: 1.1663 data: 0.0002 max mem: 12126
Epoch: [0] [ 140/2669] eta: 0:50:00 lr: 0.000020 loss: 6.8800 (6.8984) time: 1.1699 data: 0.0002 max mem: 12126
Traceback (most recent call last):
  File "/home/code/train_bert.py", line 367, in <module>
    main(args)
  File "/home/code/train_bert.py", line 301, in main
    train_stats = train_one_epoch(
  File "/home/code/engine.py", line 68, in train_one_epoch
    loss_scaler(loss, optimizer, clip_grad=max_norm,
  File "/home/code/util/misc.py", line 290, in __call__
    self._scaler.scale(loss).backward(create_graph=create_graph)
  File "/home/.conda/envs/cuda11/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/.conda/envs/cuda11/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'BmmBackward0' returned nan values in its 1th output.
```
The error happened when calling `scale(loss).backward()`, but I have no idea where `BmmBackward0` comes from (@albanD, is this a bug? The traceback doesn’t point to the line of code where the `bmm` was applied). The `Attention` forward pass contains a matrix multiplication, `attn @ v`, so I assumed the error came from there and applied the same clamping to `v`, but this did not help.
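A possible reason the report is hard to map back to a single line: in current PyTorch versions, `@` on two 4-D tensors is computed by reshaping them to 3-D and calling `bmm`, so in `Attention.forward` both `q @ k.transpose(-2, -1)` and `attn @ v` leave a `BmmBackward0` node in the graph, not only `attn @ v`. The sketch below (toy shapes, CPU, float32) walks the autograd graph of a stripped-down attention computation and counts node names; the helper `collect_node_names` is hypothetical, and the names of neighbouring nodes may vary between PyTorch versions.

```python
import torch

torch.manual_seed(0)
B, H, N, d = 2, 4, 5, 8  # toy sizes: batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, N, d, requires_grad=True) for _ in range(3))

attn = (q @ k.transpose(-2, -1)) * d ** -0.5
attn = attn.softmax(dim=-1)
out = attn @ v


def collect_node_names(grad_fn):
    """Hypothetical helper: iteratively walk the autograd graph and count node names."""
    counts, stack, seen = {}, [grad_fn], set()
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        name = type(node).__name__
        counts[name] = counts.get(name, 0) + 1
        stack.extend(next_fn for next_fn, _ in node.next_functions)
    return counts


counts = collect_node_names(out.grad_fn)
bmm_nodes = sum(c for name, c in counts.items() if name.startswith("BmmBackward"))
print(bmm_nodes)  # expected: 2, one for q @ k^T and one for attn @ v
```

If this holds for the model above, "1th output" of a `BmmBackward0` refers to the gradient with respect to the second operand, which would be `k.transpose(-2, -1)` for the first product and `v` for the second.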
I hope someone here can help me resolve this issue, or at least find out where it comes from. Thank you very much in advance!