Understanding PyTorch native mixed precision

Hi all,

I am currently working with PyTorch's native mixed precision, and I am uncertain about how some elements are handled under the hood. I tried reading the scripts in torch.cuda.amp, but they did not answer my questions.

How does PyTorch handle the casting of specific modules, such as batch norm, under the hood? I know that NVIDIA's APEX library has different levels of mixed precision (see: https://nvidia.github.io/apex/amp.html), where batch norm can be explicitly kept in float32.

Is this also the case for PyTorch native AMP, or does it simply convert everything to float16? If so, how would I go about converting my batch norm layers to float32? This also pertains to properly handling custom normalization layers such as Inplace Activated BatchNorm (https://github.com/mapillary/inplace_abn), which should be in float32 for the best performance.

Hope you can help me figure this out.

Best regards,
Joakim

PyTorch maintains a list of operations that autocast will run in fp16; operations not on that list remain in fp32. Batch normalization is one of the ops that stays in fp32 when you use amp.autocast(), so you do not need to convert it manually.
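You can verify this yourself by checking the output dtypes inside an autocast region. A minimal sketch (it uses CPU autocast with bfloat16 so it runs without a GPU; the same op-level policy applies under torch.cuda.amp.autocast() with float16):

```python
import torch
import torch.nn as nn

lin = nn.Linear(8, 8)       # matmul-based op: on the autocast low-precision list
bn = nn.BatchNorm1d(8)      # batch norm: runs in fp32 under autocast
x = torch.randn(4, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y_lin = lin(x)
    y_bn = bn(x)

print(y_lin.dtype)  # low-precision (torch.bfloat16 here, torch.float16 on CUDA)
print(y_bn.dtype)   # torch.float32
```

Note that autocast handles the casting per-op at runtime; the parameters of both modules stay in float32 throughout.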

PyTorch native AMP is similar to APEX opt level O1. A more detailed explanation by @mcarilli can be found here.
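For custom layers that are not on autocast's fp32 list (such as InplaceABN from the question), a common pattern is to locally disable autocast and cast the inputs to float32. A hedged sketch; `ForceFP32` is a hypothetical wrapper name, not a PyTorch API:

```python
import torch
import torch.nn as nn

class ForceFP32(nn.Module):
    """Hypothetical wrapper: runs the wrapped module in fp32 even inside
    an autocast region, by disabling autocast locally and casting inputs."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Disable autocast for this sub-graph and promote the input to fp32.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return self.module(x.float())
```

Usage: wrap the custom normalization layer once when building the model, e.g. `ForceFP32(InPlaceABN(64))`; everything else in the network still benefits from autocast.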

Hi @seungjun, thank you very much for your response, it answers my question fully :smiley:.