Unpredictable NaN losses during training

I keep getting NaN losses during training in a very unpredictable way. After the first one, all the parameters in the model become NaN, forcing me to stop the training and start again.
I noticed that when the DataLoader is longer, i.e. the Dataset I'm using is larger, the problem seems to start earlier; with a smaller dataset everything works as expected. I don't understand how the length of the DataLoader can affect the training behavior, given that nothing else has changed.
I checked the output of the DataLoader multiple times and found no problem, so my guess is that at some point in training one of the parameters in the model becomes NaN and then propagates through the rest of the model. The network has Conv, GRU, ReLU and Sigmoid layers, and I'm using the AdamW optimizer with weight decay. I can't share my code here because it's quite large.
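To at least localize the problem, the kind of debugging check I have in mind looks roughly like this (a simplified sketch, not my real training loop; `model`, `loader`, `optimizer` and `criterion` stand in for my actual objects):

```python
import torch

def train_with_nan_checks(model, loader, optimizer, criterion):
    # Anomaly detection gives a traceback pointing at the op that produced
    # the first NaN/Inf in the backward pass (slow, debugging only).
    torch.autograd.set_detect_anomaly(True)

    for batch_idx, (inputs, targets) in enumerate(loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Stop as soon as the loss itself is non-finite.
        if not torch.isfinite(loss):
            print(f"non-finite loss at batch {batch_idx}: {loss.item()}")
            return

        loss.backward()
        optimizer.step()

        # Check whether any parameter was corrupted by this update.
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                print(f"non-finite values in {name} at batch {batch_idx}")
                return
```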

What usually causes such problems, and how should I go about solving it?

My only idea was to reset the training to the latest completed epoch whenever a NaN is detected. It works, but it's not efficient at all, and sometimes it causes the training to get stuck in a loop.
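For clarity, the reset idea looks roughly like this (a simplified sketch, not my actual code; `train_one_epoch` and the checkpoint path are placeholders):

```python
import math
import torch

def train_with_rollback(model, optimizer, train_one_epoch, num_epochs,
                        ckpt_path="last_good.pt"):
    # Save a known-good state before the first epoch.
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict()}, ckpt_path)

    epoch = 0
    while epoch < num_epochs:
        epoch_loss = train_one_epoch(model, optimizer)  # returns a float

        if not math.isfinite(epoch_loss):
            # Roll back and redo the same epoch from the last good state;
            # this is where the "stuck in a loop" behaviour comes from.
            ckpt = torch.load(ckpt_path)
            model.load_state_dict(ckpt["model"])
            optimizer.load_state_dict(ckpt["optimizer"])
            continue

        # Epoch finished cleanly: overwrite the checkpoint and move on.
        torch.save({"model": model.state_dict(),
                    "optimizer": optimizer.state_dict()}, ckpt_path)
        epoch += 1
```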

To my knowledge, a NaN loss usually means the loss has diverged to very large values. I would assume this might be caused by the GRU layers, since recurrent layers are prone to exploding gradients.
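If the GRU layers are the culprit, gradient clipping is a common mitigation for exploding gradients in recurrent networks. Here is a minimal sketch of a training step with clipping (the function and the `max_norm` value are placeholder assumptions, not something taken from your code):

```python
import torch

def training_step(model, batch, optimizer, criterion, max_norm=1.0):
    # max_norm=1.0 is a placeholder starting point; tune it for your model.
    inputs, targets = batch
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # Rescale gradients whose global L2 norm exceeds max_norm before the
    # update, so a single exploding gradient can't corrupt the weights.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optimizer.step()
    return loss.item()
```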

This is related to the length of the DataLoader because a larger dataset means more optimizer steps per epoch, and you aren't updating the learning rate, so divergence shows up sooner. I would assume this is happening in the first few epochs of training.
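Since you mention AdamW with what sounds like a fixed learning rate, attaching a scheduler so the rate decays over epochs might help. A minimal sketch (the scheduler choice and hyperparameters are placeholder assumptions, not something confirmed from your setup):

```python
import torch

def make_optimizer(model):
    # Placeholder hyperparameters, not values taken from your setup.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    # Halve the learning rate every 5 epochs so late-training updates shrink.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    return optimizer, scheduler

def fit(model, train_one_epoch, num_epochs):
    optimizer, scheduler = make_optimizer(model)
    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer)  # your existing per-epoch training
        scheduler.step()                   # decay the learning rate once per epoch
```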

Having a simple code snippet of your network, your training procedure, and some context about your current task would make it easier to suggest concrete fixes.