Apex amp RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR

Hi, I’m running into the following problem when using apex amp or torch.cuda.amp; the same code works fine with the standard FP32 training procedure.
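
For reference, here is roughly how amp is wired up in train.py. This is a minimal sketch: the model below is a stand-in nn.GRU with illustrative sizes rather than my real Tacotron, but the amp calls and the detect_anomaly wrapper match my actual code.

import torch
import torch.nn as nn
from apex import amp

# Stand-in model with illustrative sizes; the real model is a Tacotron
# with a GST reference encoder (see below).
model = nn.GRU(input_size=80, hidden_size=256, batch_first=True).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

batch = torch.randn(8, 100, 80, device="cuda")
with torch.autograd.detect_anomaly():        # enabled while debugging
    out, _ = model(batch)                    # O2 casts inputs to fp16
    loss = out.float().pow(2).mean()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()               # this is where it crashes
    optimizer.step()

amp prints the following at initialization, and training then fails after about 200 steps: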

Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
Defaults for this optimization level are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
Warning:  multi_tensor_applier fused unscale kernel is unavailable, possibly because apex was installed without --cuda_ext --cpp_ext. Using Python fallback.  Original ImportError was: ImportError('/opt/conda/lib/python3.6/site-packages/amp_C.cpython-36m-x86_64-linux-gnu.so: undefined symbol: _ZN3c105ErrorC1ENS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE',)
  0%|                                                                                                                   | 0/10000 [00:00<?, ?it/s]train.py:261: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with torch.autograd.detect_anomaly():
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
  0%|                                                                                                        | 3/10000 [00:28<33:09:24, 11.94s/it]
  2%|██                                                                                                     | 200/10000 [16:48<7:33:20,  2.78s/it][W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnRnnBackward. Traceback of forward call that caused the error:
  File "train.py", line 346, in <module>
    train(hparams, args)
  File "train.py", line 264, in train
    out = tacotron(batch)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
    **applier(kwargs, input_caster))
  File "/raid/tts/mixed_precision_exp/model_tacotron.py", line 497, in forward
    gst = self.gst(mels)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/raid/tts/mixed_precision_exp/model_tacotron.py", line 61, in forward
    enc_out = self.encoder(inputs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/raid/tts/mixed_precision_exp/model_tacotron.py", line 96, in forward
    memory, out = self.gru(out)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 740, in forward
    self.dropout, self.training, self.bidirectional, self.batch_first)
 (function _print_stack)
Traceback (most recent call last):
  File "train.py", line 346, in <module>
    train(hparams, args)
  File "train.py", line 269, in train
    scaled_loss.backward()
  File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR

I ran into this problem inside the Docker image nvcr.io/nvidia/pytorch:20.09-py3.
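
As an aside, I suspect the multi_tensor_applier warning in the log above just means apex's C++ extensions were built against a different PyTorch than the one in this image (the undefined symbol is a c10::Error constructor, i.e. an ABI mismatch), so rebuilding apex inside the container per its README (pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./) should restore the fused unscale kernel. Since amp falls back to a Python implementation, I don't think this warning itself causes the crash.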

Collecting environment information...
PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.14.0

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: Tesla V100-SXM3-32GB

Nvidia driver version: 450.51.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] pytorch-transformers==1.1.0
[pip3] torch==1.7.0
[pip3] torchaudio==0.6.0
[pip3] torchtext==0.6.0
[pip3] torchvision==0.8.0a0
[conda] magma-cuda110             2.5.2                         5    local
[conda] mkl                       2019.1                      144  
[conda] mkl-include               2019.1                      144  
[conda] nomkl                     3.0                           0  
[conda] numpy                     1.19.1           py36h30dfecb_0  
[conda] numpy-base                1.19.1           py36h75fe3a5_0  
[conda] pytorch-transformers      1.1.0                    pypi_0    pypi
[conda] torch                     1.7.0                    pypi_0    pypi
[conda] torchaudio                0.6.0                    pypi_0    pypi
[conda] torchtext                 0.6.0                    pypi_0    pypi
[conda] torchvision               0.8.0a0                  pypi_0    pypi

Also, there is no OOM error, and the following flags are set:

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
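
Since cuDNN is enabled, one isolation test I plan to run is the same kind of GRU in fp16 with cuDNN switched off around the call, to see whether the non-cuDNN path gets through backward (a sketch; sizes are illustrative):

import torch
import torch.nn as nn

# If this backward succeeds while the default cuDNN path raises
# CUDNN_STATUS_INTERNAL_ERROR, the cuDNN RNN kernel (CudnnRnnBackward)
# is the likely culprit.
gru = nn.GRU(input_size=128, hidden_size=128, batch_first=True).cuda().half()
x = torch.randn(4, 50, 128, device="cuda", dtype=torch.float16, requires_grad=True)

with torch.backends.cudnn.flags(enabled=False):
    out, h = gru(x)
    out.float().sum().backward()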

self.gru is the same as in https://github.com/KinglittleQ/GST-Tacotron/blob/master/GST.py#L43, except that my forward has no out = F.relu(out) line.
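
For context, the relevant part of the encoder looks roughly like this (an abridged sketch; the conv stack and the real channel sizes are in the linked file):

import torch
import torch.nn as nn

class ReferenceEncoder(nn.Module):
    # Abridged from the linked GST.py; the conv stack before the GRU
    # and the exact sizes are elided.
    def __init__(self, in_features=128, hidden_size=128):
        super().__init__()
        self.gru = nn.GRU(input_size=in_features, hidden_size=hidden_size,
                          batch_first=True)

    def forward(self, out):
        # out: (batch, time, features); under O2 it arrives as fp16.
        memory, out = self.gru(out)   # model_tacotron.py:96 in the traceback
        return memory, out
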
Would you mind giving me some advice?
Thanks.

Could you post an executable code snippet to reproduce this error, please?