Training tricks to improve stability of mixed precision

I would love to be able to use automatic mixed precision more extensively in my training, but I find that it is too unstable and often ends in NaNs. Are there any general tricks in training that people here have used to improve stability?

I’ve seen the following general tips (a rough code sketch of a few of them follows after the list):

  • plot the gradients and force unstable layers to fp32
  • bump weight decay in the optimizer
  • bump epsilon in the optimizer
  • try an exotic optimizer
  • add/try different normalization layers
  • force loss calculations to fp32

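Not a definitive recipe, but here is a minimal sketch of how a few of these tips (bumped optimizer eps and weight decay, fp32 loss, gradient inspection/clipping) can be combined in a standard torch.cuda.amp training loop. The tiny model and random data are only placeholders so the loop runs end to end, and the concrete eps/weight-decay/clip values are assumptions, not recommendations:

```python
import torch

torch.manual_seed(0)
device = "cuda"

# Tiny stand-in model and random data, just so the loop runs end to end.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
).to(device)
criterion = torch.nn.CrossEntropyLoss()

# Bumped eps (default 1e-8) and some weight decay; both values are only
# illustrative starting points.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3,
                              weight_decay=1e-2, eps=1e-6)

scaler = torch.cuda.amp.GradScaler()

for step in range(100):
    inputs = torch.randn(32, 128, device=device)
    targets = torch.randint(0, 10, (32,), device=device)
    optimizer.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast():
        logits = model(inputs)
        # Force the loss calculation to fp32.
        with torch.cuda.amp.autocast(enabled=False):
            loss = criterion(logits.float(), targets)

    scaler.scale(loss).backward()

    # Unscale so the gradients can be inspected/clipped in their true range;
    # scaler.step() knows they have already been unscaled.
    scaler.unscale_(optimizer)
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            print(f"step {step}: non-finite grad in {name}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # skipped automatically if grads are non-finite
    scaler.update()
```
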
The mixed-precision training utilities in torch.amp should already cast to the appropriate dtype if the corresponding layer would otherwise suffer from decreased numerical stability (in particular via the torch.cuda.amp.autocast context).
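As a quick illustration of that autocast behaviour: under autocast a matmul runs in float16, while a nested autocast(enabled=False) region plus an explicit .float() keeps a numerically sensitive op in float32. The logsumexp here is just a stand-in for whatever op is unstable in your model:

```python
import torch

x = torch.randn(8, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

with torch.cuda.amp.autocast():
    y = x @ w                  # autocast runs this matmul in float16
    print(y.dtype)             # torch.float16

    # Opt this region back into full precision explicitly.
    with torch.cuda.amp.autocast(enabled=False):
        z = torch.logsumexp(y.float(), dim=-1)
        print(z.dtype)         # torch.float32
```
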
I’ve also seen that adding normalization layers was generally beneficial in cases where plain FP32 training “exploded” and mixed-precision training ran into overflows and couldn’t recover.
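In case it helps, this is the sort of change meant by “add normalization layers”, sketched for a conv block with GroupNorm (the channel and group counts are arbitrary examples, not a recommendation over other norm layers):

```python
import torch.nn as nn

block_without_norm = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
)

block_with_norm = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.GroupNorm(num_groups=8, num_channels=64),  # keeps activations in a well-scaled range
    nn.ReLU(inplace=True),
)
```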