RuntimeError: Function 'SoftmaxBackward' returned nan values in its 0th output

I’m training Tacotron 2 (a TTS model), which is a seq2seq model with attention.
I use NVIDIA Apex to train the model with mixed precision, and I got the following error:

Traceback (most recent call last):
  File "train.py", line 709, in <module>
    main()
  File "train.py", line 650, in main
    scaled_loss.backward()
  File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 184, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 123, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: (140597717193280,)

This was confusing, so I added torch.autograd.set_detect_anomaly(True) to the code. Then I got:

[W python_anomaly_mode.cpp:60] Warning: Error detected in SoftmaxBackward. Traceback of forward call that caused the error:
  File "train.py", line 709, in <module>
    main()
  File "train.py", line 618, in main
    args)
  File "/workspace/tacotron2/tacotron2/model.py", line 435, in decode
    attention_weights_cat, mask)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "/workspace/tacotron2/tacotron2/model.py", line 113, in forward
    attention_weights = F.softmax(alignment, dim=1)
  File "/opt/conda/lib/python3.6/site-packages/apex/amp/wrap.py", line 28, in wrapper
    return orig_fn(*new_args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py", line 1498, in softmax
    ret = input.softmax(dim)
 (function print_stack)
Traceback (most recent call last):
  File "train.py", line 709, in <module>
    main()
  File "train.py", line 618, in main
    args)
  File "train.py", line 429, in consistency_training_subroutine
    scaled_loss.backward()
  File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 184, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 123, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Function 'SoftmaxBackward' returned nan values in its 0th output.

It looks like the error occurs when computing the gradient of the softmax that produces the attention weights.
I first printed the input tensor of the softmax, along with its max and min values, and got:

tensor([[ 3.7246,  0.2996,  1.9951,  ...,  4.5039,  5.5117,  6.2266],
        [ 2.9258, -0.5249,  0.3879,  ...,  4.8828,  4.8750,  9.7031],
        [-6.9258, -8.2266, -8.1016,  ...,    -inf,    -inf,    -inf],
        ...,
        [-9.0312, -9.9531, -9.3047,  ...,    -inf,    -inf,    -inf],
        [-1.9766, -8.1562, -5.5586,  ...,    -inf,    -inf,    -inf],
        [-8.2656, -5.7539, -6.9531,  ...,    -inf,    -inf,    -inf]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MaskedFillBackward0>)
tensor(12.4922, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)
tensor(-inf, device='cuda:0', dtype=torch.float16, grad_fn=<MinBackward1>)

Here the minimum is -inf because the padded positions in the batch are filled with -inf.
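To make the effect of -inf padding concrete, here is a minimal sketch on a made-up 2x3 tensor (not the actual alignment tensor from the model). It shows that a row with some finite scores is fine, while a row that is padded everywhere comes out of softmax as NaN. Whether any fully padded row exists in the real batch is not known from the output above, so this is only one possible source of NaN to rule out. The finite fill value at the end is a hypothetical workaround, not something taken from the Tacotron 2 code.

```python
import torch

neg_inf = float("-inf")

# Row 0 has real scores plus one padded position; row 1 is padded everywhere.
scores = torch.tensor([[1.0, 2.0, neg_inf],
                       [neg_inf, neg_inf, neg_inf]])

weights = torch.softmax(scores, dim=1)
partially_masked_ok = bool(torch.isfinite(weights[0]).all())
fully_masked_is_nan = bool(torch.isnan(weights[1]).all())

# Hypothetical workaround: fill padded positions with the most negative
# finite value of the dtype instead of -inf. A fully padded row then gets
# uniform weights rather than NaN.
padding_mask = torch.isinf(scores)
safe_scores = scores.masked_fill(padding_mask, torch.finfo(scores.dtype).min)
safe_weights = torch.softmax(safe_scores, dim=1)
```

If the check on row 1 matches what happens in a real batch, the NaN is produced in the forward pass already, before SoftmaxBackward runs.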

I think this may be because the largest finite value float16 can represent is 65504, so torch.exp(12.4922) would overflow to inf. However, I disabled Apex and the problem still appeared.
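The overflow guess can be checked with a small sketch. The values below are taken from the printed max above plus made-up neighbours, and the exponential is computed in plain Python so the example runs on CPU. It also illustrates why a large input alone should not break softmax: as far as I know, the softmax kernel subtracts the row maximum before exponentiating, so the exponentials it sees are at most 1.

```python
import math
import torch

fp16_max = torch.finfo(torch.float16).max   # largest finite float16 value
naive_exp = math.exp(12.4922)                # exponential of the printed max

# Casting the naive exponential to float16 overflows to inf.
overflowed = torch.tensor(naive_exp).to(torch.float16)

# softmax on the same magnitude stays finite, even with a -inf entry,
# because the largest exponent after subtracting the maximum is 0.
logits = torch.tensor([12.4922, 3.7246, float("-inf")])
weights = torch.softmax(logits, dim=0)
```

So a maximum of 12.4922 would overflow a hand-written exp in float16, but it is unlikely to be the whole story for F.softmax, which is consistent with the error persisting after mixed precision was turned off.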

Is there a good way to solve this problem, or can someone help me understand it better?
Thank you.

We recommend using the native mixed-precision training utility via torch.cuda.amp, which will use float32 for the softmax operation as given here.
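A minimal sketch of what a training step with torch.cuda.amp might look like. The linear layer, the data, and the loss are hypothetical stand-ins, not the Tacotron 2 training loop. Mixed precision is only enabled when a GPU is present, so the same code also runs on CPU in plain float32.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(8, 4).to(device)            # hypothetical stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

inputs = torch.randn(2, 8, device=device)     # made-up batch
targets = torch.full((2, 4), 0.25, device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_cuda):
    scores = model(inputs)                    # may run in float16 on GPU
    weights = F.softmax(scores, dim=1)        # autocast runs this in float32
    loss = F.mse_loss(weights, targets)

scaler.scale(loss).backward()                 # scale loss to avoid underflow
scaler.step(optimizer)                        # skips the step on inf/NaN grads
scaler.update()
```

Note that the scaler skips optimizer steps whose gradients contain inf or NaN, so occasional overflowed gradients are expected with loss scaling. Anomaly detection may report those as errors even though the scaler would have handled them.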