I’m trying to train a PyTorch model on an A100 with bfloat16 mixed precision.
Torch AMP maintains a list of operations deemed numerically unstable in float16 (Automatic Mixed Precision package - torch.cuda.amp — PyTorch 1.10.1 documentation) and automatically casts the inputs of those operations to float32, even when running AMP with bfloat16. Is there a way to globally tweak that list so that inputs to certain operations are not cast to full precision, without manually disabling and re-enabling autocast around every call (see the sketch below for the per-call workaround I'm trying to avoid)? For example, I'd like all softmax calls throughout my model to run in bfloat16.
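For reference, here's a minimal sketch of the behaviour I'm seeing and the per-call workaround I'd like to avoid. This assumes the standard nested-autocast pattern from the docs; the tensor shapes are just placeholders.

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)

# softmax is on autocast's float32 list, so even under bfloat16
# autocast its inputs are upcast and the output comes back in float32.
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    print(F.softmax(x, dim=-1).dtype)  # torch.float32

# The per-call workaround I'd like to avoid: locally disable autocast
# and keep the input in bfloat16 by hand.
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    with torch.cuda.amp.autocast(enabled=False):
        print(F.softmax(x, dim=-1).dtype)  # torch.bfloat16
```

Wrapping every such call site like this gets unwieldy in a large model, which is why I'm hoping the cast list itself can be adjusted globally.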