How to disable AMP in some layers?

Hi everyone, I want to disable AMP for all BatchNorm2d layers in my models, because running_var is prone to overflow when cast from float32 to float16. How can I implement this?

You can nest autocast context managers: wrap the forward pass in torch.cuda.amp.autocast() as usual, and inside the model use autocast(enabled=False) around the layers you want to keep in float32 (casting their inputs back to float32 manually).
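Here is a minimal sketch of that idea. The model name (Net) and layer sizes are just placeholders; the example uses the device-agnostic torch.autocast API (falling back to CPU/bfloat16 when no GPU is available) so it also runs without CUDA, but on a GPU it is equivalent to torch.cuda.amp.autocast:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
# autocast uses float16 on CUDA, bfloat16 on CPU
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        # Runs in reduced precision under the outer autocast context.
        x = self.conv(x)
        # Disable autocast for BatchNorm2d and cast its input back to
        # float32, so running_var is accumulated in full precision.
        with torch.autocast(device_type=device, enabled=False):
            x = self.bn(x.float())
        return x

model = Net().to(device)
with torch.autocast(device_type=device, dtype=amp_dtype):
    out = model(torch.randn(4, 3, 16, 16, device=device))

# The BatchNorm output stays float32 because autocast was disabled there.
print(out.dtype)
```

The same pattern works for any layer you want to exclude from mixed precision; alternatively, you can wrap it into a small nn.Module subclass so you don't repeat the context manager in every forward.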