Override AMP casting during bfloat16 training

I’m trying to train a PyTorch model on an A100 with bfloat16 mixed precision.

PyTorch AMP maintains a list of operations deemed numerically unstable in float16 (Automatic Mixed Precision package - torch.cuda.amp — PyTorch 1.10.1 documentation) and automatically casts the inputs of those operations to float32, even when running AMP with bfloat16. Is there a way to globally tweak that list so that inputs to certain operations are not cast to full precision, without manually enabling/disabling AMP every time they’re called? For example, I’d like all softmax calls throughout my model to run in bfloat16.
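A minimal check of the behavior described above (assuming a CUDA device and the upcast policy applies to softmax as stated): under bfloat16 autocast, a matmul should stay in bfloat16 while the softmax output reports float32.

```python
import torch

x = torch.randn(8, 16, device="cuda")
w = torch.randn(16, 16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = x @ w                      # matmul runs in bfloat16 under autocast
    s = torch.softmax(y, dim=-1)   # softmax is on the autocast-to-float32 list

print(y.dtype)  # torch.bfloat16
print(s.dtype)  # expected: torch.float32, per the upcast behavior above
```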

I’m unsure if there is an easy way to “tune” the autocast list (@mcarilli might know), but for now you could disable autocast around the desired operations, cast your data to bfloat16 via bfloat16(), and call the operations directly.
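A sketch of that workaround, assuming a hypothetical helper named bf16_softmax that wraps the pattern described above (locally disable autocast, cast the input explicitly, then call the op):

```python
import torch

def bf16_softmax(x, dim=-1):
    # Hypothetical helper: run softmax in bfloat16 by disabling autocast
    # locally and casting the input, instead of letting AMP upcast to float32.
    with torch.autocast(device_type="cuda", enabled=False):
        return torch.softmax(x.to(torch.bfloat16), dim=dim)

# Usage inside a model's forward (sketch):
# attn = bf16_softmax(scores, dim=-1)
```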
