Low-precision backprop causes exploding gradients

I am trying to run backprop on Mistral-7B on a 40 GB GPU. With 32-bit precision, I run out of memory as soon as I call backward(). With 16-bit precision, backward() runs, but some of the earlier layers' gradients become infinite. I also ran backprop with 32-bit precision on the CPU and can confirm that the gradients do not explode there. What can I do?

If you are using mixed-precision training, note that overflowing gradients are expected in a few iterations: the gradient scaler will then lower its scaling factor to avoid overflows in the next iteration. The docs explain this behavior in more detail.
On the other hand, if you are manually calling .half() on the model, overflows can easily happen and you might not be able to recover from them. In that case, use the mixed-precision utilities with float16 or bfloat16.
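A minimal sketch of what a mixed-precision training loop looks like with `torch.autocast` and `GradScaler` (the model, shapes, and optimizer here are placeholders, not the original setup; the GPU branch uses float16, while the CPU fallback uses bfloat16 since CPU autocast prefers it):

```python
import torch
import torch.nn.functional as F

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = torch.nn.Linear(16, 4).to(device)   # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# GradScaler is a pass-through when disabled, so the loop also runs on CPU
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

for _ in range(3):
    x = torch.randn(8, 16, device=device)
    target = torch.randn(8, 4, device=device)
    optimizer.zero_grad()
    with torch.autocast(device_type=device,
                        dtype=torch.float16 if use_cuda else torch.bfloat16):
        loss = F.mse_loss(model(x), target)
    scaler.scale(loss).backward()  # loss is scaled up so small grads don't underflow
    scaler.step(optimizer)         # unscales grads; skips the step if any is inf/nan
    scaler.update()                # lowers the scale after an overflow, else grows it

print(torch.isfinite(loss).item())
```

The key point is that the weights and the optimizer state stay in float32; only selected forward-pass ops run in the lower precision, and the scaler handles the occasional overflow by skipping that step and reducing the scale.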


Thank you for your reply. I am not actually using mixed precision: I do everything in float16 (I set it as the default dtype for torch and I load the model and tokenizer in float16 precision to begin with). Does that rule out your suggested explanation?
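For concreteness, the default-dtype part of the setup amounts to something like this (the Mistral-7B loading itself is omitted; a small Linear layer stands in for the model):

```python
import torch

torch.set_default_dtype(torch.float16)   # new float tensors default to fp16

x = torch.randn(2, 8)                    # created in float16
w = torch.nn.Linear(8, 8).weight         # parameters are float16 too
print(x.dtype, w.dtype)

torch.set_default_dtype(torch.float32)   # restore the usual default
```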

Yes: manually casting the model to float16 means all intermediates are computed in that dtype, and they can easily under- and overflow due to the reduced precision. What's the reason you are not using the mixed-precision utilities?