Adam+Half Precision = NaNs?

Hi guys,

I’ve been running into NaNs appearing suddenly when I train using Adam and half (float16) precision. My nets train just fine in half precision with SGD+Nesterov momentum, and they train just fine in single precision (float32) with Adam, but switching them over to half seems to cause numerical instability. I’ve fiddled with the hyperparameters a bit; upping epsilon helps a tiny bit but doesn’t fix the issue.

Is this something anyone else has info on? If not I can throw together a reproduction script and dig into the issue.

Thanks again! Been a good while since I’ve had to post on account of hitting no issues otherwise.

Half precision is super finicky during training, so I’m not surprised.

One thing I recommend trying is to do the forward + backward in half precision, but the optimizer step in float precision.
To do this, you might have to clone your parameters and cast the clones to float32. Once forward+backward is over, copy the param .data and .grad into this float32 copy, call optimizer.step on the float32 copy, and then copy the updated values back into the half-precision parameters.
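
A minimal PyTorch sketch of that suggestion, assuming a toy "model" made of a single fp16 parameter vector (`half_param`, `master_param`, and `train_with_master_weights` are hypothetical names). Forward and backward run in fp16; the gradient is copied into an fp32 master copy, Adam steps on that copy with its default eps, and the result is copied back into the fp16 parameter. The loss uses only elementwise ops so the sketch also runs on CPU; a real model would keep one fp32 master tensor per parameter.

```python
import torch

def train_with_master_weights(steps=200, lr=1e-2, seed=0):
    torch.manual_seed(seed)
    # Hypothetical toy model: one fp16 parameter used for forward/backward.
    half_param = torch.nn.Parameter(torch.randn(4).half())
    # fp32 master copy; the optimizer only ever sees this tensor.
    master_param = half_param.detach().float().requires_grad_(True)
    optimizer = torch.optim.Adam([master_param], lr=lr)  # default eps=1e-8 is fine in fp32
    target = torch.ones(4, dtype=torch.float16)

    losses = []
    for _ in range(steps):
        half_param.grad = None
        loss = ((half_param - target) ** 2).sum()  # forward pass in fp16
        loss.backward()                            # backward pass in fp16
        losses.append(loss.item())
        master_param.grad = half_param.grad.float()  # fp16 grad -> fp32 master copy
        optimizer.step()                             # Adam update entirely in fp32
        with torch.no_grad():
            half_param.copy_(master_param)           # updated fp32 weights -> fp16 model
    return half_param, master_param, losses

half_param, master_param, losses = train_with_master_weights()
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Keeping the weights in fp32 also means that small updates, which would round away if added directly to fp16 weights, still accumulate over many steps.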

Other than that, I don’t have a good idea of why Adam + half is giving NaNs.

Thanks, that’s just the answer I was looking for. I’ll try out the precision swaps and report back.

EDIT: See below; it looks like eps was the culprit after all, so this solution isn’t needed.

Alright, I got this working by hanging onto fp32 copies of the parameters and keeping all of the Adam state in fp32 as well, as shown in this Gist. I suspect you could get the same stability even more efficiently by keeping only the Adam state in fp32 (I think there’s a divide-by-0 happening somewhere), but this already gives the desired memory reduction without any loss of speed relative to fp32.

It’s probably a 0 division somewhere. Have you tried using a much larger eps (say 1e-4)? The default 1e-8 is rounded to 0 in half precision.
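
To see where the 0 division comes from, here is a simplified NumPy re-implementation of Adam's first (bias-corrected) update with every value stored in float16. It is a sketch, not PyTorch's actual code path, and `adam_first_update` is a hypothetical name. With eps=1e-8 the epsilon itself rounds to 0, so any coordinate whose squared gradient also underflows divides by zero: a zero gradient gives 0/0 = NaN, and a gradient of 1e-4 (whose square is about 1e-8) gives inf. With eps=1e-4 every update stays finite.

```python
import numpy as np

def adam_first_update(grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """First bias-corrected Adam update, with every value held in float16."""
    f16 = np.float16
    g = np.asarray(grad, dtype=f16)
    lr, beta1, beta2, eps = f16(lr), f16(beta1), f16(beta2), f16(eps)  # eps=1e-8 becomes 0.0
    m = (f16(1) - beta1) * g            # first moment after one step
    v = (f16(1) - beta2) * (g * g)      # second moment; g*g underflows for small g
    m_hat = m / (f16(1) - beta1)        # bias corrections
    v_hat = v / (f16(1) - beta2)
    with np.errstate(divide="ignore", invalid="ignore"):
        return lr * m_hat / (np.sqrt(v_hat) + eps)  # 0/0 -> nan, x/0 -> inf

grads = [0.0, 1e-4, 1e-2]
print("eps=1e-8:", adam_first_update(grads, eps=1e-8))
print("eps=1e-4:", adam_first_update(grads, eps=1e-4))
```

In PyTorch the corresponding fix is just `torch.optim.Adam(model.parameters(), eps=1e-4)`. A larger eps also damps the steps of parameters whose gradients are tiny, so it is worth keeping it no larger than needed.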

I had previously tried upping epsilon after tagging it as the culprit, but I can’t recall exactly what values I used. As of right now I’m training with eps=1e-4 and it’s working just fine. Guess I should have dug into that further, thanks!

You might want to look at this paper:
https://arxiv.org/abs/1609.07061
If you’re willing to keep a copy of the weights/gradients in FP32, you might be able to reduce the precision of the forward/backward pass much further than FP16.

For anybody who arrives here through a Google search:
This paper by NVIDIA sheds more light on training in FP16.
https://arxiv.org/abs/1710.03740

They have also kindly provided an implementation:

This is a very old post, and a Google search brought me here. For future reference, it seems Adam can now be used for half-precision training with some hyperparameter tuning:

I had a similar issue. Changing BCELoss to BCEWithLogitsLoss and Adam epsilon from 10**(-8) to 10**(-4) worked for me.
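
The switch to BCEWithLogitsLoss helps for a related reason. Below is a float16 NumPy sketch (the function names are hypothetical, and neither is PyTorch's implementation): applying the sigmoid first rounds p to exactly 1.0 for a logit of 12, so log(1 - p) becomes log(0), while the logits form max(x, 0) - x*y + log1p(exp(-|x|)) never takes the log of a rounded probability. PyTorch's BCELoss clamps its log outputs at -100, so there the naive path gives a saturated loss of 100 instead of inf; the BCEWithLogitsLoss documentation attributes its better stability to the log-sum-exp trick.

```python
import numpy as np

def bce_from_probs(logits, targets):
    """Naive BCE in float16: sigmoid first, then log (no clamping)."""
    x = np.asarray(logits, dtype=np.float16)
    y = np.asarray(targets, dtype=np.float16)
    one = np.float16(1)
    p = one / (one + np.exp(-x))  # rounds to exactly 1.0 for large x
    with np.errstate(divide="ignore"):
        return -(y * np.log(p) + (one - y) * np.log(one - p))

def bce_with_logits(logits, targets):
    """Stable BCE in float16, computed directly from the logits."""
    x = np.asarray(logits, dtype=np.float16)
    y = np.asarray(targets, dtype=np.float16)
    return np.maximum(x, np.float16(0)) - x * y + np.log1p(np.exp(-np.abs(x)))

logits, targets = [0.0, 12.0], [0.0, 0.0]
print("naive: ", bce_from_probs(logits, targets))
print("logits:", bce_with_logits(logits, targets))
```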

I found it useful to inspect the computed gradients when debugging.

```python
print([p.grad for p in model.parameters()])
```
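
Printing every gradient gets unwieldy for larger models. A small variation, assuming a standard `nn.Module` (the helper name `nonfinite_grad_names` is hypothetical), lists only the parameters whose gradients contain inf or NaN; call it after `loss.backward()` and before `optimizer.step()`.

```python
import torch

def nonfinite_grad_names(model):
    """Return the names of parameters whose .grad contains inf or NaN."""
    bad = []
    for name, p in model.named_parameters():
        # Parameters that did not take part in the backward pass have grad None.
        if p.grad is not None and not bool(torch.isfinite(p.grad).all()):
            bad.append(name)
    return bad

model = torch.nn.Linear(2, 1)
model(torch.ones(1, 2)).sum().backward()
print(nonfinite_grad_names(model))
```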

For me, the NaNs already appeared in the backward pass, before the optimizer was even called. My loss was very high, and dividing the loss by 1e5 as the last step before the backward pass helped get rid of the NaNs. Of course, performance could suffer, and the learning rate may need to be retuned, etc.
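
Dividing the loss by a fixed constant is a static form of loss scaling. The NVIDIA mixed-precision paper linked earlier in the thread uses loss scaling in the other direction (multiplying the loss so small fp16 gradients do not underflow), and a common refinement picks the scale dynamically: skip the step and shrink the scale whenever the gradients overflow, and grow it again after a run of clean steps. PyTorch's `torch.cuda.amp.GradScaler` works along these lines. Below is a minimal sketch with a hypothetical `DynamicLossScaler` class (not a library API; its defaults are illustrative). Because the scale keeps halving on overflow it can also drop below 1, which covers the very-large-loss case above. In a real setup you would unscale into the fp32 master copy rather than in fp16.

```python
import torch

class DynamicLossScaler:
    """Hypothetical helper (not a library API): dynamic loss scaling for fp16."""

    def __init__(self, init_scale=2.0 ** 15, backoff_factor=0.5,
                 growth_factor=2.0, growth_interval=1000):
        self.scale = init_scale
        self.backoff_factor = backoff_factor
        self.growth_factor = growth_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0  # consecutive steps without inf/NaN gradients

    def backward_and_step(self, loss, params, optimizer):
        """Backprop loss * scale, unscale the grads, and step only if they are finite.

        Returns True if optimizer.step() was called, False if the step was skipped.
        """
        optimizer.zero_grad(set_to_none=True)
        (loss * self.scale).backward()
        grads = [p.grad for p in params if p.grad is not None]
        if not all(bool(torch.isfinite(g).all()) for g in grads):
            # Overflow somewhere: skip this update and shrink the scale.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
            return False
        for g in grads:
            g.div_(self.scale)  # undo the scaling before the optimizer sees the grads
        optimizer.step()
        self._clean_steps += 1
        if self._clean_steps % self.growth_interval == 0:
            # A long run without overflow: try a larger scale again.
            self.scale *= self.growth_factor
        return True
```

For example, with a single fp16 parameter at 100 and the loss `(x * x).sum()`, the first calls return False while the scale halves from 2**15, and updates resume once 2 * 100 * scale fits under fp16's maximum of 65504.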

I had this problem with TensorFlow, but this is the only post I found discussing it, so I might as well share my solution. Float16 can only represent normal numbers down to about 6.1e-5 (subnormals reach about 6e-8, with very little precision), but the default Adam epsilons for TensorFlow (Keras) and PyTorch are lower than this (1e-7 and 1e-8 respectively). The epsilon therefore underflows in float16: 1e-8 rounds to exactly 0, and 1e-7 survives only as a coarse subnormal. Changing the epsilon to 1e-4 solved the problem for me.
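
A quick NumPy check of how the epsilons mentioned in this thread are stored in float16 (`eps_in_float16` is a hypothetical helper): 1e-8 rounds to exactly 0, 1e-7 survives only as a coarse subnormal (about 1.19e-7), and 1e-4 is a normal float16 value.

```python
import numpy as np

SMALLEST_NORMAL_F16 = np.finfo(np.float16).tiny   # 2**-14, about 6.1e-5
SMALLEST_SUBNORMAL_F16 = np.float16(2.0 ** -24)   # about 6.0e-8

def eps_in_float16(eps):
    """Round an epsilon to float16 and classify what is left of it."""
    stored = np.float16(eps)
    if stored == 0:
        kind = "zero (underflow)"
    elif stored < SMALLEST_NORMAL_F16:
        kind = "subnormal (reduced precision)"
    else:
        kind = "normal"
    return float(stored), kind

for eps in (1e-8, 1e-7, 1e-4):
    print(eps, *eps_in_float16(eps))
```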

Wow, thank you so much! Your solution works really well!!!

Another solution: use bitsandbytes, which has an 8-bit Adam optimizer (Adam8bit).

```python
import bitsandbytes as bnb

# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer
adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer
```

Thank you very much!