In the PyTorch docs, it is stated that:
torch.amp provides convenience methods for mixed precision, where some operations use the torch.float32 (float) datatype and other operations use lower precision floating point datatype (lower_precision_fp): torch.float16 (half) or torch.bfloat16. Some ops, like linear layers and convolutions, are much faster in lower_precision_fp. Other ops, like reductions, often require the dynamic range of float32. Mixed precision tries to match each op to its appropriate datatype.
Does that mean a linear layer will always run in float16?
I’m training a regression model whose final layer is a linear layer. The model predicts inf for target values above roughly 65k. Let me know if I’m missing anything.
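For reference, the ~65k threshold lines up with float16's finite range: its largest representable finite value is 65504, and anything beyond that overflows to inf. A minimal check (plain dtype casting, not a full autocast run):

```python
import torch

# float16's largest finite value is 65504
fp16_max = torch.finfo(torch.float16).max
print(fp16_max)  # 65504.0

# casting a value above that range to float16 overflows to inf --
# the same overflow a float16 linear-layer output would hit
y = torch.tensor([70000.0]).to(torch.float16)
print(torch.isinf(y))  # tensor([True])
```

This suggests the inf predictions come from the float16 range limit rather than from the model itself.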