I’m training a gpt2 model with HuggingFace transformers and noticed I’m running into `nan` loss values during training. I tracked the source of the `nan` to a `softmax` computation where there’s a single `inf` in the input to the `softmax`. The `inf` comes from the matrix multiply of the query and key matrices that computes the attention weights: specifically, the dot product of one query/key vector pair overflows the `float16` dtype. This is all happening within a `torch.cuda.amp.autocast` context for the forward pass of the model.
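To make the failure concrete, here’s a tiny standalone repro of the same pattern. The values are made up for illustration and are not the actual GPT-2 activations; the point is just that a `float16` dot product can overflow to `inf`, and one `inf` in a softmax row turns the whole row into `nan`:

```python
import torch

# Stand-in for the query/key dot product: hypothetical large activations.
q = torch.full((128,), 32.0, dtype=torch.float16)
k = torch.full((128,), 32.0, dtype=torch.float16)

score = (q * k).sum()   # 128 * 32 * 32 = 131072 > 65504 (float16 max) -> inf
print(score)            # tensor(inf, dtype=torch.float16)

# One inf among otherwise fine attention scores poisons the softmax output,
# and the resulting nan then propagates into the loss.
row = torch.tensor([1.0, 2.0, float(score)])
print(torch.softmax(row, dim=0))  # tensor([nan, nan, nan])
```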
I’m still relatively new to `amp`, but I thought the `autocast` context was supposed to handle converting between `float16` and `float32` on the fly in cases like an overflow? If not, I wonder if some clipping would be appropriate to prevent the overflow in this specific setting, but that might be a suggestion for the HF folks.