Unexplainable NaN loss

I am currently trying to train a triplet-based variational autoencoder.
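For context, the loss combines a reconstruction term, a KL term, and a triplet margin term on the latent codes. The following is only an illustrative sketch of that kind of loss, not my actual implementation (the MSE reconstruction term and the margin/beta/alpha values here are placeholders):

```python
import torch
import torch.nn.functional as F

def triplet_vae_loss(recon, target, mu, logvar, z_anchor, z_pos, z_neg,
                     margin=1.0, beta=1.0, alpha=1.0):
    # Reconstruction term (MSE as a placeholder; the real choice depends on the data).
    recon_loss = F.mse_loss(recon, target, reduction="mean")

    # KL divergence between q(z|x) = N(mu, exp(logvar)) and the standard normal prior.
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    # Triplet margin term on the latent codes.
    triplet = F.triplet_margin_loss(z_anchor, z_pos, z_neg, margin=margin)

    # beta and alpha are placeholder weights for the KL and triplet terms.
    return recon_loss + beta * kl + alpha * triplet
```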

I used 5-fold cross-validation and ran into training instabilities only on the last two splits, and only for the largest model configuration. Otherwise, training worked well.

Now I am trying to train my models on the full dataset. During training I report the mean training loss after every epoch, and from the first epoch on the reported value is NaN. However, when I check the mean loss of each individual batch, there are usually no NaN values at all during the first epoch, even though the reported epoch loss is still NaN.
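For reference, the loss tracking is roughly equivalent to this simplified sketch (not my exact training loop); the per-batch check uses torch.isfinite so that an inf would be caught as well as a NaN:

```python
import torch

def train_one_epoch(model, loader, optimizer, compute_loss, device="cpu"):
    # compute_loss is a placeholder for the full triplet-VAE loss.
    batch_losses = []
    for anchor, positive, negative in loader:
        optimizer.zero_grad()
        loss = compute_loss(model, anchor.to(device),
                            positive.to(device), negative.to(device))

        # Per-batch sanity check: torch.isfinite also catches +/-inf,
        # which a plain torch.isnan check would miss.
        if not torch.isfinite(loss):
            print(f"non-finite batch loss: {loss.item()}")

        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Reported epoch loss: a single non-finite batch is enough to poison this mean.
    return sum(batch_losses) / len(batch_losses)
```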

Afterwards, usually within the second epoch, the loss values within individual batches also become NaN, and the model never recovers. Interestingly, I never see just a few NaNs in a batch: at first there are none (as described above), and then there are only NaNs.

I have already tried decreasing the learning rate and clipping the gradients, to no avail. I also checked almost every intermediate tensor for NaN values, but the behaviour is the same as described: everything suddenly becomes NaN at a certain point. I attempted to find the culprit with torch.autograd.set_detect_anomaly(True), which simply reported that a MulBackward0 call received NaN as its first input. However, that didn't get me anywhere.
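For completeness, the clipping and anomaly detection were applied roughly like this (simplified sketch; the max_norm value is a placeholder, and the per-parameter gradient check is just an extra diagnostic I added):

```python
import torch

# Enables anomaly detection: records forward-pass stack traces and raises
# when NaN values show up during backward (debugging only, slows training).
torch.autograd.set_detect_anomaly(True)

def training_step(model, optimizer, loss, max_norm=1.0):
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping; max_norm=1.0 is only a placeholder value.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

    # Extra diagnostic: report the first parameter whose gradient is non-finite.
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            print(f"non-finite gradient in {name}")
            break

    optimizer.step()
```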

I’m out of ideas at this point. Do you have any suggestions?


Can you create a minimal code example that reproduces the issue?