Limiting precision of specific buffers

Similar to: Half precision: ignoring a buffer.
I have a model that uses a custom module with several specific buffers that I want to keep in at least single precision (float32), because half precision is not accurate enough for them. Everything else in the model can be converted to half.
What is the best approach? Currently, I’m considering overriding the _apply method in my custom module, but this feels a bit risky and inelegant. Any suggestions?

I assume you are calling .to() directly on your model, which then transforms all registered floating-point parameters and buffers to the desired dtype. If so, you could transform or re-initialize those buffers as float32 tensors afterwards. Note that explicitly transforming tensors to torch.float16 is not the recommended way to use mixed precision. Use the autocast approach instead, which keeps all parameters and buffers in float32 and runs only eligible operations in float16.
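A minimal sketch of the "transform the buffers back" idea, assuming the sensitive buffers can be identified by their attribute names. `CustomModule`, `scale`, and `half_except` are hypothetical names used only for illustration. Since `.half()` has already rounded the values by the time you call `.float()` on them, this version snapshots the float32 buffers before the conversion and restores those copies afterwards.

```python
import torch
import torch.nn as nn


class CustomModule(nn.Module):
    # Hypothetical module: `scale` stands in for a precision-sensitive buffer.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("scale", torch.tensor([1.0 + 1e-4]))

    def forward(self, x):
        # Do the sensitive multiplication in float32, then return the input dtype.
        return (self.linear(x).float() * self.scale).to(x.dtype)


def half_except(model: nn.Module, keep_fp32: set) -> nn.Module:
    """Hypothetical helper: model.half(), but buffers named in keep_fp32 stay float32."""
    saved = {}
    for mod_name, mod in model.named_modules():
        for buf_name, buf in mod.named_buffers(recurse=False):
            if buf_name in keep_fp32:
                # Snapshot before .half() rounds the values.
                saved[(mod_name, buf_name)] = buf.detach().clone().float()
    model.half()
    for (mod_name, buf_name), buf in saved.items():
        # The name is already a registered buffer, so Module.__setattr__
        # replaces the entry in _buffers (it stays in the state_dict).
        setattr(model.get_submodule(mod_name), buf_name, buf)
    return model


model = CustomModule()
half_except(model, {"scale"})
print(model.linear.weight.dtype, model.scale.dtype)  # torch.float16 torch.float32
```

Restoring a pre-conversion snapshot is used here instead of re-initializing, because not every buffer can be recomputed; if yours can, re-initializing after `.half()` works just as well.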

Yes, I am using .half() on the entire model, which I assume is equivalent to using .to(torch.float16).
Is autocast also the recommended approach for inference, or for saving a smaller model? That is where I see the accuracy loss.
I am transforming the buffers back to float32 during training, but that does not prevent the precision loss when saving or loading the model for fp16 inference.
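For the "smaller checkpoint" part, one possible approach is to cast only the state_dict being saved, rather than the model itself, leaving selected keys in float32. The helper `to_half_state_dict` is hypothetical, and BatchNorm's `running_var` is used merely as a stand-in for a precision-sensitive buffer.

```python
import io

import torch
import torch.nn as nn


def to_half_state_dict(model: nn.Module, keep_fp32: set) -> dict:
    """Hypothetical helper: float16 copy of the state_dict for a smaller checkpoint,
    leaving keys in keep_fp32 (and non-floating-point tensors) unchanged."""
    out = {}
    for key, tensor in model.state_dict().items():
        if tensor.is_floating_point() and key not in keep_fp32:
            out[key] = tensor.half()
        else:
            out[key] = tensor
    return out


model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model[1].running_var.fill_(1.0 + 1e-4)  # a value float16 cannot represent exactly

buffer = io.BytesIO()
torch.save(to_half_state_dict(model, {"1.running_var"}), buffer)
buffer.seek(0)
loaded = torch.load(buffer)

# load_state_dict copies into the existing float32 tensors, so the float16
# weights are upcast on load while running_var is restored exactly.
restored = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
restored.load_state_dict(loaded)
```

This only shrinks the file; the weights still lose precision from the float16 round trip, so the accuracy impact has to be checked for your model.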

Yes, autocast can also be used during inference and shouldn't cause accuracy issues, assuming training worked fine, since the same float32 parameters and buffers will be used. Manually transforming the parameters and buffers to float16 for inference could work, but you would need to verify the accuracy for your use case.
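A minimal sketch of autocast inference, assuming a plain float32 model (the small `nn.Sequential` is only a placeholder). The model is never converted, so its parameters and buffers stay in float32; autocast only changes the dtype of eligible op outputs. float16 is used on CUDA, and bfloat16 is used as the CPU fallback because it is the long-supported lower-precision dtype for CPU autocast.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16
model.to(device)
x = torch.randn(4, 8, device=device)

with torch.inference_mode(), torch.autocast(device_type=device, dtype=amp_dtype):
    out = model(x)

# Module state is untouched; only the activations were computed in amp_dtype.
print(next(model.parameters()).dtype)  # torch.float32
print(out.dtype == amp_dtype)  # True
```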