I am currently working with PyTorch's native mixed precision (torch.cuda.amp), and I am uncertain about how some elements are handled under the hood. I tried reading the source in torch.cuda.amp, but it did not answer my questions.
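For context, my training loop is basically the standard autocast + GradScaler recipe from the docs, sketched here with a toy model and random data (the `enabled=` guards are just so this snippet also runs on a CPU-only machine):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = torch.cuda.is_available()

# Toy model containing a BatchNorm layer, which is what my question is about.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(4, 3, 16, 16, device=device)
target = torch.randn(4, 8, 16, 16, device=device)

for _ in range(3):
    optimizer.zero_grad()
    # Forward pass under autocast; on CUDA, eligible ops run in float16.
    with torch.cuda.amp.autocast(enabled=use_amp):
        out = model(x)
        loss = nn.functional.mse_loss(out, target)
    # Scale the loss to avoid float16 gradient underflow, then step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```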
How does PyTorch handle the casting of specific modules, such as Batch Norm, under the hood? I know that NVIDIA's APEX library has different mixed-precision opt levels (see: https://nvidia.github.io/apex/amp.html), where Batch Norm can be explicitly kept in Float32.
Is this also the case for PyTorch native amp, or does it simply convert everything to Float16? If the latter, how would I go about keeping my batch norm layers in Float32? This also pertains to how to properly handle custom normalization layers such as Inplace Activated BatchNorm (https://github.com/mapillary/inplace_abn), which should be in Float32 for best performance.
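The workaround I am currently considering for such custom layers is a small wrapper that locally disables autocast and casts inputs up to float32 before calling the wrapped module. The `ForceFP32` name and the whole approach are just my own sketch, not an official API, so I would like to know whether this is the idiomatic way:

```python
import torch
import torch.nn as nn

class ForceFP32(nn.Module):
    """Run a wrapped module in float32, even inside an autocast region.
    (My own sketch -- not an official PyTorch API.)"""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Disable autocast locally and cast the input to float32, so that
        # e.g. a custom BatchNorm always computes in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            return self.module(x.float())

# Usage: wrap the layer I want to keep in full precision.
bn = ForceFP32(nn.BatchNorm2d(8))
out = bn(torch.randn(2, 8, 4, 4))
print(out.dtype)  # torch.float32
```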
Hope you can help me figure this out.