Does torch.cuda.amp support O2 (“almost FP16”) training now?

Hello,

When I use apex amp, it always shows a warning:
Warning: apex.amp is deprecated and will be removed by the end of February 2023. Use PyTorch AMP (see “Automatic Mixed Precision package - torch.amp” in the PyTorch 2.1 documentation).
So I am trying to migrate my code from apex to PyTorch's built-in AMP.
With apex, my code used the “O2” opt level, i.e. “almost FP16” mixed precision.
However, I don't see anything similar in PyTorch's built-in torch.autocast.
It only offers two lower-precision dtypes: torch.float16 and torch.bfloat16.
Does this mean PyTorch did not implement “O2” mixed precision? If so, why?

I am also curious about training speed. In my runs, the ranking from fastest to slowest is:
autocast with torch.float16 > autocast with torch.bfloat16 > O1 > O0 > O3 > O2.
Why is O2 the slowest?
Note: O0 = FP32; O1 = mixed precision; O2 = almost FP16; O3 = pure FP16.
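For context on the O2/O3 labels: according to the apex documentation, O2 casts the model weights to FP16 but keeps FP32 master weights (and FP32 batchnorm) for the optimizer step, whereas O3 is pure FP16. The small NumPy sketch below (values chosen arbitrarily) illustrates why the FP32 master copy matters: near 1.0, adjacent FP16 values are roughly 5e-4 to 1e-3 apart, so an update of 1e-4 is rounded away in FP16 but accumulates correctly in FP32.

```python
import numpy as np

lr_times_grad = 1e-4  # arbitrary small weight update

# O3-style: the weight itself is FP16, so the update is lost to rounding.
w16 = np.float16(1.0)
w16_new = np.float16(w16 - np.float16(lr_times_grad))

# O2-style: updates go to an FP32 master copy, which is cast back to FP16.
master = np.float32(1.0)
for _ in range(10):
    master = np.float32(master - np.float32(lr_times_grad))
w16_from_master = np.float16(master)

print(w16_new)          # 1.0: the single update was rounded away
print(master)           # close to 0.999: all ten updates were kept
print(w16_from_master)  # the FP16 model weight now reflects the updates
```

This only shows what the master weights are for; it does not by itself explain the timing differences above.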

Thank you very much!

Yes, that’s correct: native PyTorch AMP does not provide an O2 mode, since the deprecated apex implementation was too limited in flexibility and disabled a few important use cases. There are workarounds now, such as using a custom optimizer that holds the (FP32 master) parameter states.
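As an illustration of that kind of workaround, here is a minimal sketch of an O2-style setup, assuming plain SGD and a static loss scale (apex O2 used dynamic scaling). The class `MasterWeightSGD` is hypothetical, not a PyTorch API and not necessarily the approach the answer has in mind: the model's parameters live in low precision, while the wrapper keeps FP32 master copies, steps on those, and writes the result back. It uses float16 on GPU and falls back to bfloat16 on CPU.

```python
import torch
import torch.nn as nn

class MasterWeightSGD:
    """Hypothetical O2-style wrapper: low-precision model, FP32 master weights."""

    def __init__(self, model_params, lr, loss_scale=1.0):
        self.model_params = [p for p in model_params if p.requires_grad]
        # FP32 master copies; the inner optimizer only ever sees these.
        self.master_params = [p.detach().clone().float().requires_grad_(True)
                              for p in self.model_params]
        self.inner = torch.optim.SGD(self.master_params, lr=lr)
        self.loss_scale = loss_scale  # static scale, for simplicity

    def zero_grad(self):
        for p in self.model_params:
            p.grad = None
        self.inner.zero_grad(set_to_none=True)

    @torch.no_grad()
    def step(self):
        # Copy low-precision grads into FP32 and undo the loss scale.
        for mp, p in zip(self.master_params, self.model_params):
            if p.grad is not None:
                mp.grad = p.grad.float() / self.loss_scale
        self.inner.step()
        # Write the updated FP32 weights back into the low-precision model.
        for mp, p in zip(self.master_params, self.model_params):
            p.copy_(mp)

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
low = torch.float16 if use_cuda else torch.bfloat16

torch.manual_seed(0)
model = nn.Linear(16, 4).to(device=device, dtype=low)  # "almost FP16" weights
opt = MasterWeightSGD(model.parameters(), lr=0.1, loss_scale=128.0)

x = torch.randn(32, 16, device=device, dtype=low)
y = torch.randn(32, 4, device=device, dtype=low)
w_before = opt.master_params[0].detach().clone()

opt.zero_grad()
loss = nn.functional.mse_loss(model(x).float(), y.float())  # reduce in FP32
(loss * opt.loss_scale).backward()
opt.step()
```

Unlike O2 in apex, this sketch does not keep batchnorm layers in FP32 or skip steps on overflowing gradients; those would need to be added for real training.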
