My model (with many transformer layers) is noticeably sensitive to numerical precision, so I am trying to find the points in the model that are particularly sensitive. I am not getting any errors.
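One way to look for sensitive points is to run the same input through a float32 copy and a low-precision copy of the model and compare the output of every leaf module. The following is only a rough sketch: the toy `nn.Sequential` model and the `capture_outputs` helper are made up for illustration and are not taken from the model discussed here.

```python
import copy

import torch
import torch.nn as nn


def capture_outputs(model: nn.Module, x: torch.Tensor) -> dict:
    """Run one forward pass and record every leaf module's output as float32."""
    outputs = {}
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            def hook(mod, inp, out, name=name):
                outputs[name] = out.detach().float()
            handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for handle in handles:
        handle.remove()
    return outputs


torch.manual_seed(0)
# Toy stand-in for a real model (assumption, for illustration only).
model = nn.Sequential(
    nn.Linear(32, 64), nn.LayerNorm(64), nn.GELU(), nn.Linear(64, 32)
).eval()
x = torch.randn(8, 32)

ref = capture_outputs(model, x)
low_model = copy.deepcopy(model).to(torch.bfloat16)
low = capture_outputs(low_model, x.to(torch.bfloat16))

# Relative error of each layer's low-precision output against float32.
rel_err = {
    name: ((low[name] - ref[name]).norm() / ref[name].norm()).item()
    for name in ref
}
for name, err in rel_err.items():
    print(f"layer {name}: relative error {err:.4f}")
```

Because every layer receives the already-rounded output of the previous one, the errors accumulate with depth. A jump at one specific layer is a stronger signal than a slow drift.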
So, I was wondering whether forcing the LayerNorms to compute in float32 might make my model more robust to numerical errors. I wanted to know if such a practice (forcing float32 for LayerNorm only) is common, or whether such a concern is completely irrelevant.
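For concreteness, here is a minimal sketch of what "float32 for LayerNorm only" could look like. `FP32LayerNorm` is a hypothetical name chosen for this example, not an existing PyTorch class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FP32LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in float32 (hypothetical helper)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Upcast the input and the affine parameters so that the
        # mean/variance reductions are carried out in float32.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.layer_norm(
            x.float(), self.normalized_shape, weight, bias, self.eps
        )
        # Cast back so downstream layers see the dtype they expect.
        return out.to(orig_dtype)


torch.manual_seed(0)
x = torch.randn(4, 16, 64, dtype=torch.bfloat16)
norm = FP32LayerNorm(64).to(torch.bfloat16)
y = norm(x)
print(y.dtype, y.shape)
```

As far as I know, `torch.autocast` on CUDA already lists `layer_norm` among the ops that run in float32, so a wrapper like this would mainly matter when the whole model is cast with `.half()` or `.bfloat16()` rather than run under autocast.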
While I can’t speak for your particular case, lower precision during training generally tends to add a degree of regularization (i.e. it helps prevent overfitting) and can be desirable for this, among other reasons such as reduced VRAM usage.
With that said, if a model was trained at a higher precision, running inference at a lower precision may give lower accuracy and other performance metrics than running it at the precision it was trained on.
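Note that, if I’m not mistaken, putting a model in .eval() only turns off layers such as Dropout and switches BatchNorm to its running statistics; LayerNorm still normalizes its input at inference, so its precision matters there as well.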
I didn’t know that using low precision helps regularization. I will stick to using it.
Based on what you suggested, it seems that forcing LayerNorm to use float32 even during mixed precision training (or, inference) is not a common practice. Thank you for your comments.
Actually, I created this post after looking at the source code of PyTorch’s RMSNorm.
Inside this code, they first cast the input to float32 and then cast the result back to the original dtype at the end, presumably so that the normalization itself is computed more precisely.
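To make the pattern concrete, here is a rough sketch of an RMSNorm that upcasts, normalizes, and casts back. It is written from scratch for illustration (`UpcastRMSNorm` and its default `eps` are assumptions) and is not a copy of PyTorch's source.

```python
import torch
import torch.nn as nn


class UpcastRMSNorm(nn.Module):
    """RMSNorm that computes in float32 and returns the input dtype (sketch)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x32 = x.float()  # upcast before squaring
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        out = x32 * inv_rms * self.weight.float()
        return out.to(orig_dtype)  # cast back at the end


# 300**2 = 90000 exceeds the largest float16 value (65504), so squaring
# these inputs directly in float16 would overflow; in float32 it is fine.
x = torch.full((1, 4), 300.0, dtype=torch.float16)
norm = UpcastRMSNorm(4)
out = norm(x)
print(out.dtype, out)
```

The squared values are where the extra range and precision help most, which is probably why the upcast happens before the reduction rather than after it.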
Do you (or anyone else) know if such a practice is also desired for LayerNorm when a model is sensitive to numerical errors?