Best practice for using LayerNorm with reduced precision

I want to know how people are using LayerNorm with reduced precision (float16, bfloat16).

Main questions are:

  • Do I need to cast inputs back to float32 for an accurate LayerNorm computation?
  • Does LayerNorm cast reduced-precision inputs to float32 automatically?
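One way to check whether reduced precision matters for LayerNorm in a given setting is to compare it against a float32 reference. This is an illustrative sketch, not an established recipe: `layernorm_error` is a hypothetical helper, the input shape and the mean offset of 5 are arbitrary assumptions, and it only tests bfloat16 on CPU (on a GPU, `torch.float16` could be swapped in). The reference is computed on the already-rounded input, so the number reflects the normalization itself rather than the cost of storing activations in bfloat16.

```python
import torch
import torch.nn.functional as F

def layernorm_error(x: torch.Tensor, dtype: torch.dtype) -> float:
    """Hypothetical helper: max abs difference between LayerNorm run in `dtype`
    and a float32 reference computed on the same dtype-rounded input."""
    x_low = x.to(dtype)                        # what the model actually sees
    shape = x.shape[-1:]                       # normalize over the last dimension
    out_low = F.layer_norm(x_low, shape)       # computed and returned in `dtype`
    ref = F.layer_norm(x_low.float(), shape)   # float32 reference
    return (out_low.float() - ref).abs().max().item()

torch.manual_seed(0)
x = torch.randn(4, 1024) * 3 + 5  # assumed input: non-zero mean, std of about 3
print("bfloat16 max abs error:", layernorm_error(x, torch.bfloat16))
```

Repeating this on activations captured from the real model gives a more relevant number than random data.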

Thank you.

What problem are you trying to solve by using higher precision in this layer?

Normally the LayerNorm operation is handled in whatever precision you have designated your model parameters and inputs as.
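A minimal check of that statement, assuming a module cast wholesale with `.to(torch.bfloat16)` and no autocast: both the LayerNorm parameters and its output stay in bfloat16. This differs from running under `torch.autocast`, where, per the autocast op reference, `layer_norm` is one of the ops that autocast to float32 on CUDA.

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(16).to(torch.bfloat16)       # parameters now stored in bfloat16
x = torch.randn(2, 16, dtype=torch.bfloat16)   # reduced-precision input
y = ln(x)
print(ln.weight.dtype, y.dtype)  # torch.bfloat16 torch.bfloat16
```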

Are you getting an error otherwise?

Thank you for the clarification, Mr Johnson.

My model (with many transformer layers) is noticeably sensitive to numerical precision, so I am trying to find the points in the model where it is most sensitive. I am not getting any errors.
So I was wondering whether forcing LayerNorm to compute in float32 might make my model more robust to numerical errors. I wanted to know whether such a practice (forcing float32 for LayerNorm only) is common, or whether the concern is irrelevant.
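To locate sensitive points, one option is to run a float32 copy and a reduced-precision copy of the model on the same input and compare intermediate outputs with forward hooks. This is a sketch under assumptions: `capture_outputs` and `per_layer_error` are hypothetical helpers, the toy `nn.Sequential` stands in for the real transformer, and the reported errors are cumulative (each layer inherits the drift of the layers before it), so the useful signal is where the error jumps, not its absolute level.

```python
import copy
import torch
import torch.nn as nn

def capture_outputs(model, x, types=(nn.Linear, nn.LayerNorm)):
    """Run `model` on `x` and record the output of every module of the given types."""
    outputs, handles = {}, []
    for name, module in model.named_modules():
        if isinstance(module, types):
            # Bind `name` as a default argument so each hook writes to its own key.
            def hook(mod, args, out, name=name):
                outputs[name] = out.detach().float()
            handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()  # leave the model without hooks afterwards
    return outputs

def per_layer_error(model, x, dtype=torch.bfloat16):
    """Relative max error of each hooked layer when run in `dtype` versus float32."""
    ref = capture_outputs(copy.deepcopy(model).float(), x.float())
    low = capture_outputs(copy.deepcopy(model).to(dtype), x.to(dtype))
    return {
        name: ((low[name] - ref[name]).abs().max() / ref[name].abs().max()).item()
        for name in ref
    }

torch.manual_seed(0)
model = nn.Sequential(  # toy stand-in for a transformer block
    nn.Linear(32, 64), nn.LayerNorm(64), nn.GELU(),
    nn.Linear(64, 32), nn.LayerNorm(32),
).eval()
x = torch.randn(8, 32)
for name, err in per_layer_error(model, x).items():
    print(f"{name}: {err:.2e}")
```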

Thanks again,

While I can’t speak for your particular case, lower precision during training generally tends to add some regularization (i.e. it helps prevent overfitting), which is desirable for that reason among others, such as reduced VRAM usage.

With that said, if a model was trained at a higher precision, running inference at a lower precision may reduce accuracy and other performance metrics compared with the precision it was trained in.

However, if I’m not mistaken, putting a model in .eval() for inference turns off dropout and switches BatchNorm to its running statistics; LayerNorm itself behaves the same in both modes.
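For reference, a quick way to see what `.eval()` changes for these layer types. In this toy check (random input, default module settings), `nn.LayerNorm` gives identical output in train and eval mode, `nn.Dropout` becomes the identity in eval mode, and `nn.BatchNorm1d` switches from batch statistics to its running statistics. `run` is a hypothetical helper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8)
ln, drop, bn = nn.LayerNorm(8), nn.Dropout(p=0.5), nn.BatchNorm1d(8)

def run(mod, train):
    """Call `mod` on `x` in train (True) or eval (False) mode."""
    mod.train(train)
    with torch.no_grad():
        return mod(x)

print(torch.equal(run(ln, True), run(ln, False)))     # True: LayerNorm ignores the mode
print(torch.equal(run(drop, False), x))               # True: Dropout is the identity in eval
print(torch.allclose(run(bn, True), run(bn, False)))  # False: BatchNorm uses running stats in eval
```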


I didn’t know that using low precision helps regularization. I will stick to using it.

Based on what you suggested, it seems that forcing LayerNorm to use float32 during mixed-precision training (or inference) is not a common practice. Thank you for your comments.

Actually, I created this post after looking at the source code of PyTorch’s RMSNorm.
That code first casts the input to float32 and then casts the result back to the original dtype at the end, presumably to compute the normalization more accurately.
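The cast-up/cast-down pattern described above can be sketched as follows. This is an illustrative module, not a copy of PyTorch’s source: `RMSNormFP32` is a hypothetical name, and implementations differ on details such as whether the weight is applied before or after casting back to the original dtype.

```python
import torch
import torch.nn as nn

class RMSNormFP32(nn.Module):
    """Hypothetical RMSNorm that does its arithmetic in float32."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x32 = x.float()  # upcast: the mean of squares is computed in float32
        x32 = x32 * torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Scale in float32, then cast back so downstream layers see the original dtype.
        return (x32 * self.weight.float()).to(orig_dtype)

norm = RMSNormFP32(16).to(torch.bfloat16)
x = torch.randn(3, 16).to(torch.bfloat16)
print(norm(x).dtype)  # torch.bfloat16
```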

Do you (or anyone else) know whether the same practice is also desirable for LayerNorm when a model is sensitive to numerical errors?
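If one wanted to try the same idea for LayerNorm, a minimal sketch is a subclass of `nn.LayerNorm` whose forward pass upcasts the input and parameters to float32 and casts the result back. `FP32LayerNorm` is a hypothetical name, and whether it helps a particular model is an empirical question, e.g. answered by comparing its outputs against a full float32 run.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FP32LayerNorm(nn.LayerNorm):
    """Hypothetical LayerNorm that always normalizes in float32."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast parameters too (they may be None if elementwise_affine=False).
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(x.dtype)  # hand the result back in the caller's dtype

ln = FP32LayerNorm(16).to(torch.bfloat16)  # parameters stored in bfloat16
x = torch.randn(2, 16).to(torch.bfloat16)
print(ln(x).dtype)  # torch.bfloat16
```

To try it in an existing model, each `nn.LayerNorm` could be replaced by an `FP32LayerNorm` with the same `normalized_shape` and `eps`, copying `weight` and `bias` over with `load_state_dict`, since the parameter names are unchanged.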

Thanks again!