NaN values in loss after a few epochs

I was training Swin Transformers with SimMIM, using Hugging Face's Swin implementation together with a custom SimMIM implementation.
After the warmup epochs, the loss either settles at a fixed value and stops improving (the downstream task then predicts the same class for every input) or goes to NaN. I've implemented gradient clipping and am using a small learning rate (1e-4), but neither has fixed the issue. With a larger learning rate (8e-4), the loss diverges and then settles at the fixed value mentioned above. None of this happens with the ViT models, which converge without issues.
Can anyone please suggest how I can fix this? Thanks in advance.

Finally fixed this issue. It turns out that switching from FP16 autocasting to FP32 resolved the convergence problems.
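For reference, the change amounts to dropping the FP16 autocast context and the GradScaler and running the loop in plain FP32. A minimal sketch, assuming a standard PyTorch loop (`model`, `loader`, and `optimizer` are placeholders, and the clipping value is illustrative):

```python
import torch

# Previously: FP16 autocast plus a GradScaler, e.g.
#   with torch.autocast(device_type="cuda", dtype=torch.float16): ...
#   scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()

# Now: plain FP32, with no autocast context and no GradScaler.
for batch in loader:
    optimizer.zero_grad()
    loss = model(**batch).loss  # placeholder: any model that returns a loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # the clipping mentioned above
    optimizer.step()
```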

Thank you for sharing your solution.
I’m encountering similar issues with SwinV2 using the original SimMIM repository. On my custom dataset, the loss goes to NaN after roughly 8 epochs, and I haven’t found a fix yet.

In my case, I’ve been using Automatic Mixed Precision (AMP) from the beginning.
While switching to FP32 could potentially help, it would be ideal to find a solution that maintains the memory and speed benefits of mixed precision.

Has anyone else experienced stability or convergence issues with the combination of SimMIM and Swin Transformers? Are these problems widely known or reported?

Additionally, has anyone successfully trained this combination on larger datasets or for extended periods without encountering NaN losses or convergence problems?

In the end I resorted to layer-wise learning rate decay and training in BF16. I think the issues are due to overflow caused by a large LR combined with the limited precision of FP16, which BF16 manages to avoid to a large extent.
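In case it’s useful, here’s a rough sketch of both pieces in a plain PyTorch loop (`model` and `loader` are placeholders). It assumes the model exposes its stages as `model.layers`, as in the original SimMIM Swin; the decay factor, base LR, and clipping value are illustrative:

```python
import torch

def layerwise_lr_groups(model, base_lr=1e-4, decay=0.9):
    """Geometrically smaller LRs for earlier stages (illustrative values)."""
    stages = list(model.layers)          # assumes Swin-style `model.layers`
    n = len(stages)
    seen, groups = set(), []
    for i, stage in enumerate(stages):
        params = [p for p in stage.parameters() if p.requires_grad]
        seen.update(id(p) for p in params)
        groups.append({"params": params, "lr": base_lr * decay ** (n - 1 - i)})
    # Remaining parameters (patch embedding, norms, head) at the base LR;
    # a full LLRD scheme would scale the embedding down as well.
    rest = [p for p in model.parameters() if p.requires_grad and id(p) not in seen]
    groups.append({"params": rest, "lr": base_lr})
    return groups

optimizer = torch.optim.AdamW(layerwise_lr_groups(model), weight_decay=0.05)

# BF16 autocast: no GradScaler needed, and gradients are far less prone
# to overflowing than under FP16.
for batch in loader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss       # placeholder forward pass
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clipping value illustrative
    optimizer.step()
```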

FP16 provides a higher precision than BF16 at the cost of a narrower range.
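Concretely, comparing the two formats with `torch.finfo`:

```python
import torch

for dtype in (torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max:.3e}, eps={info.eps:.3e}")

# float16:  max ≈ 6.55e4,  eps ≈ 9.8e-4  -> finer steps, but overflows above ~65k
# bfloat16: max ≈ 3.39e38, eps ≈ 7.8e-3  -> coarser steps, FP32-like range
```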

Sorry for that. I meant the range. The gradients were exploding, so switching to BF16 helped them stay within limits.

I’ve been exploring different approaches to address these stability issues while maintaining the benefits of mixed precision. Here are two strategies I’ve implemented:

  1. More aggressive gradient clipping: I reduced the clip-grad-norm from 5 to 1, which seems to be working.
  2. Precision switching: I start training in bfloat16 for the initial phase, where the learning rate is higher; as you mentioned, this helps avoid overflow. Later in training, once the learning rate is much smaller and the network has mostly converged, I switch to float16. A rough sketch of both changes follows this list.
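Roughly, assuming a plain PyTorch loop (`switch_epoch`, `num_epochs`, `model`, `loader`, and `optimizer` are placeholders; engaging the GradScaler only for the FP16 phase is my own choice):

```python
import torch

switch_epoch = 60                      # illustrative: well past warmup, LR already small
scaler = torch.cuda.amp.GradScaler()   # only used during the FP16 phase

for epoch in range(num_epochs):
    # bfloat16 early (high LR, overflow-prone), float16 once the LR is small
    amp_dtype = torch.bfloat16 if epoch < switch_epoch else torch.float16
    use_scaler = amp_dtype is torch.float16

    for batch in loader:
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=amp_dtype):
            loss = model(**batch).loss  # placeholder forward pass

        if use_scaler:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale before clipping
        else:
            loss.backward()

        # 1. more aggressive clipping: max_norm=1 instead of 5
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        if use_scaler:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
```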

These methods have shown some promise in improving stability while still leveraging the memory and speed advantages of mixed precision training.

Both experiments are running at the moment and looking good so far.

Your thoughts on these strategies and any further recommendations would be helpful.