How to find causes of NaN grads?

After upgrading to PyTorch 1.12.1 (conda, CUDA 11.3), I get NaN grads for some Conv2d weights and biases right after validation:

- Epoch 0 training.
- Validation.
- Epoch 1 training.
- Validation.
...
- Epoch N training.
- Validation.
- Got NaN in the first step of epoch N+1.
- Got NaN in the second step of epoch N+2.
- Got NaN in the third step of epoch N+3.
...

If I save the model’s state dict and inputs when the NaN appears, then run the forward and backward passes manually in another Python process with the saved objects, the grads are normal, with no NaN.
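Roughly, the reproduction attempt looks like this (a simplified sketch of what I run; the file name, criterion and autocast call are stand-ins for my actual code):

import torch

# At the step where the NaN grad shows up: dump everything needed to replay it.
torch.save({"state_dict": model.state_dict(),
            "inputs": inputs.detach().cpu(),
            "targets": targets.detach().cpu()}, "nan_step.pt")

# In a separate Python process: reload and replay the forward and backward passes.
ckpt = torch.load("nan_step.pt")
model.load_state_dict(ckpt["state_dict"])
model = model.cuda().to(memory_format=torch.channels_last)
inputs = ckpt["inputs"].cuda().to(memory_format=torch.channels_last)
with torch.autocast("cuda", dtype=torch.float16):
    loss = criterion(model(inputs), ckpt["targets"].cuda())
loss.backward()
# In this fresh process the grads come out clean (no NaN).
print(any(p.grad.isnan().any().item() for p in model.parameters() if p.grad is not None))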

I use model.to(memory_format=torch.channels_last) and fp16 training. If I remove model.to(memory_format=torch.channels_last) or train in fp32, training works normally without NaN grads.
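The relevant part of the setup is basically channels_last plus fp16 autocast; simplified to a generic loop it looks like this (the GradScaler and optimizer details here are only illustrative, not my exact code):

import torch

model = model.cuda().to(memory_format=torch.channels_last)  # removing this line avoids the NaNs
scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    inputs = inputs.cuda().to(memory_format=torch.channels_last)
    with torch.autocast("cuda", dtype=torch.float16):  # fp32 training also avoids the NaNs
        loss = criterion(model(inputs), targets.cuda())
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()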

The same setup also works in PyTorch 1.11.

Are there any suggestions for finding the causes of random NaN grads?

Hi @mikasa,

Have a look at torch.autograd.detect_anomaly, which is designed to find which operations cause NaNs. (docs here)
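A minimal way to use it is to wrap the forward and backward of a single training step; model, inputs, targets and criterion below are placeholders for your own objects:

import torch

with torch.autograd.detect_anomaly():
    out = model(inputs)              # the forward is tracked so the failing op can be reported
    loss = criterion(out, targets)
    loss.backward()                  # raises a RuntimeError and prints the traceback of the
                                     # forward op whose backward produced the NaN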

Thank you @AlphaBetaGamma96. I have tried torch.autograd.detect_anomaly, but it just said that ConvolutionBackward0 returned NaN values.

The convolution does not get NaN values if I run the model in another Python process with the same states and inputs. It also works normally if I use PyTorch 1.11, or PyTorch 1.12 with the NCHW memory format.

This isn’t referencing the output of your convolution layer; it’s referencing the gradients of the convolution layer for its 0-th output.

Could you share the model you’re using, as well as the full error message you get when running torch.autograd.detect_anomaly? If you have more than one convolution layer, it’ll be harder to know which one is causing the NaN issue (although it’s most likely the last one).

It seems that I can’t post the full model and trace here.
I set a breakpoint at the line containing the conv with the NaN gradient and ran the following code:

# Conv2d.
print(module)  # Conv2d(32, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0), bias=False) 
print(module.weight.stride()) # (256, 1, 32, 32)
print(module.weight.dtype) # torch.float32
print(module.weight.isnan().any()) # tensor(False, device='cuda:0')

# Conv2d's input.
print(input.shape) # torch.Size([48, 32, 64, 376])
print(input.stride()) # (770048, 1, 12032, 32)
print(input.dtype)  # torch.float16
print(input.isnan().any())  # tensor(False, device='cuda:0')

# OK: no NaN in the input grad.
print(torch.autograd.grad(module(input)[0, 0, 0, 0], input)[0].isnan().any()) # tensor(False, device='cuda:0')
# Why does `* 0.0` get a NaN grad?
print(torch.autograd.grad(module(input)[0, 0, 0, 0] * 0.0, input)[0].isnan().any()) # tensor(True, device='cuda:0')

I’m not sure why the last line (and the full conv backward) gets a NaN grad (in PyTorch 1.12.1) after running some validation passes.
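One further check along these lines is a full backward hook on that conv (module below is the Conv2d from the snippet above), to see whether the NaN already arrives in grad_output or is first produced by the convolution backward itself; a rough sketch:

def report_nan(mod, grad_input, grad_output):
    # grad_output: gradients flowing into this module's backward pass;
    # grad_input: gradients it produces w.r.t. its inputs.
    if any(g is not None and g.isnan().any() for g in grad_output):
        print("NaN already present in grad_output of", mod)
    if any(g is not None and g.isnan().any() for g in grad_input):
        print("NaN produced by the backward of", mod)

handle = module.register_full_backward_hook(report_nan)
# Run the failing training step, then handle.remove().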