FP16 overflow when computing matmul in autocast context

I’m training a gpt2 model with HuggingFace Transformers and noticed I’m running into NaN loss values during training. I tracked the source of the NaN to a softmax computation where there’s a single inf in the input to the softmax. The inf is coming from the matrix multiply of the query and key matrices used to calculate the attention weights. Specifically, the dot product of a query vector with a key vector overflows the float16 dtype. This is all happening within a torch.cuda.amp.autocast context for the forward pass of the model.
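Here is a minimal, synthetic repro of the kind of overflow I’m seeing (the values are made up for illustration, not the actual GPT-2 tensors):

```python
import torch

# Synthetic example: under autocast the matmul runs in float16, and a dot
# product whose true value exceeds the float16 max (~65504) overflows to inf,
# which then turns into nan after the softmax.
q = torch.full((1, 1024), 16.0, device="cuda")   # float32 inputs
k = torch.full((1024, 1), 16.0, device="cuda")

with torch.cuda.amp.autocast():
    scores = q @ k                        # autocast runs this matmul in float16
    print(scores.dtype)                   # torch.float16
    print(scores)                         # inf, since 1024 * 16 * 16 = 262144 > 65504
    print(torch.softmax(scores, dim=-1))  # nan propagates from the inf
```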

I’m still relatively new to AMP, but I thought the autocast context was supposed to handle converting between float16 and float32 on the fly to avoid cases like this overflow? If not, I wonder whether some clipping would be appropriate to prevent the overflow in this specific setting, but that might be a suggestion for the HF guys :wink:

No, autocast will not check for overflows; it converts inputs/outputs to lower-precision dtypes for operations that are considered numerically safe to run in float16.
The numerical bounds still apply, and if you expect values outside the float16 range (max ~65504) you might need to disable autocast for that part of the computation.
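Something along these lines could work as a workaround: locally disable autocast and compute the attention scores in float32 (the q/k shapes below are placeholders, not the actual HF GPT-2 implementation):

```python
import torch

# Sketch: pull the overflow-prone matmul out of the autocast region so the
# dot products are accumulated in float32 and cannot overflow at ~65504.
q = torch.randn(1, 12, 128, 64, device="cuda")   # (batch, heads, seq, head_dim)
k = torch.randn(1, 12, 128, 64, device="cuda")

with torch.cuda.amp.autocast():
    # ... the rest of the forward pass stays in mixed precision ...
    with torch.cuda.amp.autocast(enabled=False):
        # .float() undoes any earlier down-cast, so this matmul runs in float32
        attn_scores = torch.matmul(q.float(), k.float().transpose(-1, -2))
    attn_weights = torch.softmax(attn_scores, dim=-1)
```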

Generally, it would also be interesting to learn more about the use case and why such large values are expected.