Training and exporting model weights in float16

I’d like to quantize my model weights to 16 bits for the speed and memory savings in deployment. So far, every approach I’ve tried has hit a dead end:

- `torch.cuda.amp`, which seems to be the standard recommendation for training acceleration, keeps the master weights in float32, so the saved model still requires full 32-bit precision; casting those weights to float16 after the fact for inference costs accuracy.
- The built-in `torch.quantization` tools appear to be limited to int8 outputs.
- Simply calling `.half()` on my model and inputs and training as normal has the stability issues you might expect.
- The documentation is ambiguous about whether you can explicitly set the dtypes of particular layers under AMP, but trying to force specific `Linear` layers to float16 produces the same `GradScaler` errors reported elsewhere.
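For context on the accuracy concern: float16 has only a 10-bit mantissa and a max finite value of 65504, so a straight float32-to-float16 cast rounds or clips many weight values. A quick stdlib-only illustration (no PyTorch needed), round-tripping values through Python's `struct` half-precision format:

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack('e', struct.pack('e', x))[0]

# 10-bit mantissa: values near 1.0 are spaced 2**-10 apart,
# so small differences between weights collapse entirely.
print(to_fp16(0.1))      # 0.0999755859375 (rounded)
print(to_fp16(1.0001))   # 1.0 (difference lost)
print(to_fp16(65504.0))  # 65504.0 (largest finite float16)
print(to_fp16(1e-8))     # 0.0 (underflows below ~6e-8)
```

Whether this rounding actually hurts a given model depends on its weight distribution, which is part of what I'm hoping a recommended training recipe would address.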

Is there a recommended best practice for training a model that is ultimately destined for export in 16-bit floating-point format?