`GradScaler`+AMP causes NaN weights on backward pass

TL;DR: when I use AMP GradScaler with two different losses (scaling each one separately), after about 100 epochs, training crashes due to NaN weights on backward.

I am trying to train a self-implemented DC-CDN, which uses two losses (Contrastive Depth Loss and Mean Squared Error).

In my implementation I’ve used autocast for both the forward function and the loss computation (in particular, if it helps, I use autocast as a decorator on both of these functions, so that it is never enabled at any other point during training).
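
For reference, here is a minimal sketch of that setup. `TinyDepthNet` and `compute_losses` are placeholders standing in for my actual DC-CDN and loss code, just to show where the decorator sits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

class TinyDepthNet(nn.Module):           # placeholder for my DC-CDN
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    @autocast()                          # forward always runs under autocast
    def forward(self, x):
        return self.conv(x)

@autocast()                              # loss computation also runs under autocast
def compute_losses(pred, target):
    loss1 = F.l1_loss(pred, target)      # stand-in for my Contrastive Depth Loss
    loss2 = F.mse_loss(pred, target)
    return loss1, loss2
```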

I’ve also used GradScaler, initially summing both losses:

```python
scaler.scale(loss1 + loss2).backward()
scaler.step(opt)
scaler.update()
```
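
For context, the full iteration looks roughly like this (reusing the `TinyDepthNet` / `compute_losses` placeholders from above, with random tensors in place of my real data loader):

```python
from torch.cuda.amp import GradScaler

model = TinyDepthNet().cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

for _ in range(10):
    inputs = torch.randn(4, 3, 64, 64, device="cuda")
    targets = torch.randn(4, 1, 64, 64, device="cuda")

    opt.zero_grad()
    preds = model(inputs)                      # runs under autocast via the decorator
    loss1, loss2 = compute_losses(preds, targets)
    scaler.scale(loss1 + loss2).backward()     # one backward on the summed, scaled loss
    scaler.step(opt)                           # skipped internally if inf/NaN grads are found
    scaler.update()
```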

However, as I’ve learned from the AMP recipe, this falls under an advanced use case, so I changed the above code to this:

```python
# each loss is scaled separately
scaler.scale(loss1).backward(retain_graph=True)
scaler.scale(loss2).backward()
scaler.step(opt)
scaler.update()
```

I made this change after reading the discussion in this GitHub issue.

Disabling GradScaler or autocast (either one, or both) has allowed me to finish my experiments without crashing, but my understanding is that this could cause problems down the line: disabling autocast results in longer training times, which is not ideal, and running without gradient scaling could itself lead to NaN weights on certain datasets.
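
In case it’s relevant, one way to run these ablations without restructuring the loop is the `enabled` flag that both `autocast` and `GradScaler` accept. A sketch of the pattern (this assumes `autocast` is used as a context manager here rather than the decorator from above, since an inner `@autocast()` would re-enable it; `use_autocast` / `use_scaler` are just local variables for this post):

```python
use_autocast = False                     # e.g. turn off mixed-precision casting only
use_scaler = True                        # keep (or drop) gradient scaling independently

scaler = GradScaler(enabled=use_scaler)  # a disabled scaler makes scale/step/update pass-throughs

with autocast(enabled=use_autocast):
    preds = model(inputs)
    loss1, loss2 = compute_losses(preds, targets)

scaler.scale(loss1 + loss2).backward()
scaler.step(opt)
scaler.update()
```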

Also, the fact that the GradScaler step is not preventing the NaN weights as it stands hints at something being wrong with my implementation.

The backward pass only calculates the gradients, and NaN gradients can be expected if the scaling factor is too high. The weights are not manipulated at all at this point, and scaler.step(optimizer) will skip the parameter update if invalid gradients are detected. Is this behavior not observed?
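
If it’s unclear whether that skipping actually happens, one way to check is to unscale the gradients before `step()` and watch the scale factor across iterations (`unscale_` and `get_scale` are existing `GradScaler` methods; this is only a debugging sketch dropped into the loop above):

```python
scaler.scale(loss1 + loss2).backward()

scaler.unscale_(opt)                     # gradients are now in their true (unscaled) range
bad = [name for name, p in model.named_parameters()
       if p.grad is not None and not torch.isfinite(p.grad).all()]
if bad:
    print("non-finite grads in:", bad)

prev_scale = scaler.get_scale()
scaler.step(opt)                         # sees the already-unscaled grads and skips on inf/NaN
scaler.update()
if scaler.get_scale() < prev_scale:      # the scale is backed off whenever a step was skipped
    print("step skipped, scale lowered to", scaler.get_scale())
```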