LayerNorm's grads become NaN after first epoch

  1. Your NaNs emerge when computing the gradient of your loss w.r.t. your parameters, so you won’t see them in your input, only once gradients are computed. If your loss is Inf, the gradients of that loss w.r.t. the parameters will be NaN (there’s a minimal sketch of this after the list).

  2. Clamping the output to stop it overflowing could help, but a simpler solution would be to ask whether you really need to be running your code in torch.float16 (see the second sketch below).

  3. The hook prints the gradient that is used during the optimizer step, so I assume it’s the scaled gradient (as that’s what AMP uses during backprop). The third sketch below shows how to look at the unscaled gradients instead.
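
A minimal, self-contained sketch of point 1 (not from the original thread; the values are illustrative). It uses float32 so it runs anywhere, but the same thing happens much earlier in float16, whose largest finite value is only 65504:

```python
import torch

# Once an overflow makes the loss Inf, backprop through that overflow
# poisons the parameter gradients with NaN, even though the inputs and
# parameters themselves are perfectly finite.
w = torch.tensor([90.0, 1.0], requires_grad=True)

loss = torch.log(torch.exp(w).sum())  # exp(90) overflows float32 -> loss is Inf
loss.backward()

print(loss)    # tensor(inf, ...)
print(w.grad)  # tensor([nan, 0.]) -- the NaN shows up only in the gradient
```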
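For point 2, a hedged sketch of two alternatives to raw float16 (the model and names here are illustrative, not from the original post): clamping activations to float16’s finite range as a stop-gap, or running autocast in bfloat16, which keeps float32’s exponent range and so doesn’t overflow at 65504:

```python
import torch

# (a) Stop-gap: clamp activations to float16's finite range before they overflow.
fp16_max = torch.finfo(torch.float16).max  # 65504.0

def clamp_to_fp16_range(x: torch.Tensor) -> torch.Tensor:
    return x.clamp(min=-fp16_max, max=fp16_max)

# (b) Prefer bfloat16 autocast if your hardware supports it; use
#     device_type="cuda" on a GPU.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
x = torch.randn(8, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

print(torch.isfinite(out).all())  # activations stay finite
```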
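And for point 3, a sketch of the standard GradScaler loop (assumes a CUDA device; all tensor names are illustrative). A gradient hook fires while backprop runs on `scaler.scale(loss)`, so it sees the scaled gradient; calling `scaler.unscale_(optimizer)` before the step exposes the true values:

```python
import torch

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

# This hook fires during backward and therefore prints the *scaled* gradient.
model.weight.register_hook(lambda g: print("hook (scaled):", g.norm().item()))

x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 1, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()   # hook fires here with the scaled gradient
scaler.unscale_(optimizer)      # model.weight.grad now holds unscaled values
print("unscaled:", model.weight.grad.norm().item())
scaler.step(optimizer)          # skipped automatically if grads contain Inf/NaN
scaler.update()
optimizer.zero_grad()
```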