torch.amp provides convenience methods for mixed precision, where some operations use the torch.float32 (float) datatype and other operations use lower precision floating point datatype (lower_precision_fp): torch.float16 (half) or torch.bfloat16. Some ops, like linear layers and convolutions, are much faster in lower_precision_fp. Other ops, like reductions, often require the dynamic range of float32. Mixed precision tries to match each op to its appropriate datatype.
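A minimal sketch of this routing, using CPU autocast with bfloat16 so it runs without a GPU (the tensor names are illustrative):

```python
import torch

a = torch.randn(8, 8)  # float32 inputs
b = torch.randn(8, 8)

# matmul is on autocast's lower-precision op list, so inside the
# region it runs in bfloat16 even though both inputs are float32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    mm = a @ b

print(mm.dtype)  # torch.bfloat16
```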

Does that mean linear layers will always run in float16?
I’m training a regression model whose final layer is a linear layer, and the model predicts inf for target values above ~65k. Let me know if I’m missing anything.
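The ~65k threshold matches float16's dynamic range: the largest finite float16 value is 65504, so anything bigger overflows to inf. A quick check:

```python
import torch

# The largest finite value representable in float16
print(torch.finfo(torch.float16).max)  # 65504.0

# Casting a larger value to float16 overflows to inf
x = torch.tensor(70000.0)
print(x.half())
```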

You can wrap the evaluation of the last linear layer in an autocast(enabled=False) block; this is briefly touched on at the bottom of the autocast context manager's documentation.

Thanks for the input.

By default, linear layers will use fp16 computation under autocast. The context manager lets you deviate from that by turning AMP off for certain parts of the model. For example, you could make a “LinearFP32” subclass of Linear that calls the super's forward inside the with block mentioned above.
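One way that subclass could look ("LinearFP32" is the hypothetical name from the discussion, not a built-in torch module):

```python
import torch
import torch.nn as nn

class LinearFP32(nn.Linear):
    """A Linear layer that always computes in float32, even when
    called inside an autocast region (hypothetical helper)."""
    def forward(self, x):
        # Locally disable AMP, cast the input back to float32, and
        # run the normal Linear forward at full precision.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return super().forward(x.float())

head = LinearFP32(32, 1)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = head(torch.randn(4, 32))
print(out.dtype)  # torch.float32
```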