Use of FP16 in backward() causes NaN in gradients only when create_graph = True

I am using PyTorch in a chatbot training routine and would like to get FP16’s advantages in GPU memory and speed. The routine works fine in FP32, but after switching it to FP16 the problem appears to come from loss.backward(). For my optimizer to work, I need to pass create_graph = True to backward. When I do that with the model I am working with (for every loss scale factor I’ve tried), backward gives NaN for some of the parameter gradients while others look reasonable. However, when I use create_graph = False to see how the gradients compare, no NaN is generated, and the gradients that were not NaN in the first case agree with the ones computed here. Since create_graph is supposed to enable computing higher-order derivatives (and should not change the first derivative), I am confused why create_graph = True has this effect. Any suggestions would be appreciated!
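The comparison described in the question (one backward pass with create_graph = True, one without, then checking which gradients are NaN and whether the finite ones agree) can be scripted so it is easy to rerun for different loss scales. This is a minimal sketch, assuming a generic nn.Module and a loss function that takes the model as its only argument. The helper names grads_for and compare_create_graph, the toy model, and the scale factor of 128 are all hypothetical; they are not part of ParlAI or PyTorch. FP16 is used only when CUDA is available, since half-precision matmuls may be unsupported on CPU in older releases.

```python
import torch
import torch.nn as nn


def grads_for(model, loss_fn, create_graph):
    """Run one backward pass and return {parameter name: detached copy of its grad}."""
    model.zero_grad(set_to_none=True)
    loss = loss_fn(model)
    loss.backward(create_graph=create_graph)
    # detach() so the copies do not keep the recorded backward graph alive
    return {n: p.grad.detach().clone() for n, p in model.named_parameters() if p.grad is not None}


def compare_create_graph(model, loss_fn, rtol=1e-2, atol=1e-3):
    """List parameters whose grads are non-finite, or differ between the two modes."""
    plain = grads_for(model, loss_fn, create_graph=False)
    higher = grads_for(model, loss_fn, create_graph=True)
    report = {"nonfinite_plain": [], "nonfinite_higher": [], "mismatch": []}
    for name, g in plain.items():
        h = higher[name]
        g_ok = bool(torch.isfinite(g).all())
        h_ok = bool(torch.isfinite(h).all())
        if not g_ok:
            report["nonfinite_plain"].append(name)
        if not h_ok:
            report["nonfinite_higher"].append(name)
        # Compare in FP32 so the tolerance check itself is not done in half precision
        if g_ok and h_ok and not torch.allclose(g.float(), h.float(), rtol=rtol, atol=atol):
            report["mismatch"].append(name)
    # Drop the create_graph=True grads, which reference the recorded graph
    model.zero_grad(set_to_none=True)
    return report


torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
# FP16 on GPU; FP32 fallback on CPU, where half matmuls may be unsupported
dtype = torch.float16 if use_cuda else torch.float32
model = nn.Sequential(nn.Linear(16, 32), nn.Tanh(), nn.Linear(32, 1)).to(device=device, dtype=dtype)
x = torch.randn(8, 16, device=device, dtype=dtype)
loss_scale = 128.0  # hypothetical loss scale factor
report = compare_create_graph(model, lambda m: m(x).pow(2).mean() * loss_scale)
print(report)
```

Setting the gradients to None at the end drops the references from each parameter's .grad back into the recorded backward graph, which create_graph = True otherwise keeps alive between steps.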

Could you post a minimal, executable code snippet to reproduce the issue, please?
Also, the output of python -m torch.utils.collect_env would be needed.

The code is a fine-tuning job for a chatbot with over 400M parameters (BlenderBot2, see ParlAI/projects/blenderbot2 in the facebookresearch/ParlAI repository on GitHub), so it does not lend itself to a minimal snippet that shows the issue. Is there another way to provide useful information?
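One way to get more information without sharing the whole model is PyTorch's anomaly detection mode: inside torch.autograd.detect_anomaly(), backward raises a RuntimeError as soon as a backward function returns NaN, names that function, and emits a warning with the traceback of the forward operation that created it. This is a minimal sketch, assuming a loss function that takes the model as its only argument. first_nan_backward and SqrtAtZero are hypothetical names; SqrtAtZero is a toy module that deliberately produces 0/0 inside the backward of sqrt so the kind of message it reports can be seen.

```python
import torch
import torch.nn as nn


def first_nan_backward(model, loss_fn, create_graph=True):
    """Run one backward pass in anomaly mode; return the error text, or None if no NaN appeared."""
    model.zero_grad(set_to_none=True)
    try:
        # Anomaly mode checks every backward function's outputs for NaN and
        # raises RuntimeError naming the first offending function.
        with torch.autograd.detect_anomaly():
            loss = loss_fn(model)
            loss.backward(create_graph=create_graph)
    except RuntimeError as err:
        return str(err)
    finally:
        model.zero_grad(set_to_none=True)
    return None


class SqrtAtZero(nn.Module):
    """Hypothetical toy module: d/dw sqrt(w) at w = 0 combined with a zero
    upstream gradient gives 0 / 0 = NaN inside the sqrt backward."""

    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(3))

    def forward(self):
        return torch.sqrt(self.w) * 0.0


msg = first_nan_backward(SqrtAtZero(), lambda m: m().sum())
print(msg)
```

Anomaly mode slows training considerably, so it is best run on a single batch with create_graph = True; the name of the backward function it reports, together with the forward traceback, is the kind of detail that can be posted instead of the full model.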

Here is the env info…
(ParlAI) C:\Users\Steve>python -m torch.utils.collect_env
Collecting environment information…
PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.19.1
Libc version: N/A

Python version: 3.8 (64-bit runtime)
Python platform: Windows-10-10.0.19042-SP0
Is CUDA available: True
CUDA runtime version: 11.2.142
GPU models and configuration: GPU 0: GeForce RTX 2070 SUPER
Nvidia driver version: Could not collect
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\bin\cudnn_ops_train64_8.dll
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.3
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.9.0
[pip3] torch-optimizer==0.1.0
[pip3] torchtext==0.10.0
[conda] blas 2.106 mkl conda-forge
[conda] cudatoolkit 11.1.1 heb2d755_7 conda-forge
[conda] libblas 3.9.0 6_mkl conda-forge
[conda] libcblas 3.9.0 6_mkl conda-forge
[conda] liblapack 3.9.0 6_mkl conda-forge
[conda] liblapacke 3.9.0 6_mkl conda-forge
[conda] mkl 2020.4 hb70f87d_311 conda-forge
[conda] numpy 1.21.1 py38h09042cb_0 conda-forge
[conda] pytorch 1.9.0 py3.8_cuda11.1_cudnn8_0 pytorch
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] torch 1.9.0 pypi_0 pypi
[conda] torch-optimizer 0.1.0 pypi_0 pypi
[conda] torchtext 0.10.0 pypi_0 pypi

Could you post a code snippet showing how you’ve initialized the model, along with the shapes of the inputs? We could then use random inputs with those shapes to try to reproduce the issue instead.
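A self-contained script along these lines, with the stand-in model replaced by the real initialization and the random tensors given the real batch shapes, is the kind of thing that could be run by others. It is a sketch under assumptions: TinyLM, run_repro, and the sizes VOCAB, D_MODEL, BATCH, and SEQ_LEN are hypothetical placeholders, not BlenderBot2's architecture or configuration. Computing the loss in FP32 and the loss scale of 1024 are illustrative choices, and FP16 is used only when a GPU is available.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes: replace with the real model config and the real batch shapes.
VOCAB, D_MODEL, BATCH, SEQ_LEN = 1000, 64, 4, 32


class TinyLM(nn.Module):
    """Hypothetical stand-in for the real model: embedding -> residual MLP block -> vocab projection."""

    def __init__(self, vocab=VOCAB, d=D_MODEL):
        super().__init__()
        self.embed = nn.Embedding(vocab, d)
        self.norm = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.Tanh(), nn.Linear(4 * d, d))
        self.out = nn.Linear(d, vocab)

    def forward(self, tokens):
        h = self.embed(tokens)
        h = h + self.ff(self.norm(h))
        return self.out(h)


def run_repro(seed=0, loss_scale=1024.0, create_graph=True):
    """One forward/backward step on random tokens; returns (loss, params with non-finite grads)."""
    torch.manual_seed(seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # FP16 on GPU as in the question; FP32 fallback on CPU so the script still runs.
    dtype = torch.float16 if use_cuda else torch.float32
    model = TinyLM().to(device=device, dtype=dtype).train()
    tokens = torch.randint(0, VOCAB, (BATCH, SEQ_LEN), device=device)
    targets = torch.randint(0, VOCAB, (BATCH, SEQ_LEN), device=device)

    logits = model(tokens)
    # Illustrative choice: compute the cross-entropy in FP32.
    loss = F.cross_entropy(logits.float().view(-1, VOCAB), targets.view(-1))
    (loss * loss_scale).backward(create_graph=create_graph)

    bad = [n for n, p in model.named_parameters()
           if p.grad is not None and not torch.isfinite(p.grad).all()]
    model.zero_grad(set_to_none=True)  # release grads that reference the recorded graph
    return loss.item(), bad


loss, bad = run_repro()
print(f"loss={loss:.4f}, params with non-finite grads: {bad}")
```

Running it once with create_graph=True and once with create_graph=False mirrors the comparison in the original question; if the stand-in stays clean, swapping in the real model's modules one at a time would help narrow down which part produces the NaN.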