Reasons for and inspection of NaN loss during Seq2Seq training?

I have a Seq2Seq model that I’ve been training and working on for the past few months. Recently I made a couple of modifications to the data module loading, the learning rate warm-up, etc., and I also created a new dataset of 10 million training pairs. During training, the loss suddenly becomes NaN in the 2nd or 3rd epoch. This is really confusing, since the loss doesn’t climb or drop beforehand; it just becomes NaN suddenly, and in the 3rd epoch, which means the model successfully went over the whole dataset at least once!

To troubleshoot, I tried training the new model on my previous dataset, and it trains perfectly on the old data! I checked my new (large) dataset and removed all NaN and null values from it. What else can I check in the dataset to make sure it doesn’t crash training? It’s an NLP dataset of text sentence pairs, like machine translation, and I’m using BERTTokenizerFast for tokenization. Since there are 10 million rows, it’s practically impossible to read through it manually.
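For instance, a quick scan like the one below would catch rows that survive a NaN/null filter but can still break training: empty or whitespace-only strings, and pairs that tokenize to more tokens than the model’s position embeddings allow. This is only a sketch; the file name, the src/tgt column names, and the 512-token limit are assumptions to adapt to your data module.

```python
import pandas as pd
from transformers import BertTokenizerFast

# Hypothetical file/column names -- adjust to your data module.
df = pd.read_csv("train_pairs.tsv", sep="\t", names=["src", "tgt"], dtype=str)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

MAX_LEN = 512          # assumed model position limit
bad_rows = set()

# Empty / missing strings tokenize to (almost) nothing, which can give
# zero-length targets and an undefined loss.
empty = (
    df["src"].isna() | df["tgt"].isna()
    | (df["src"].str.strip() == "") | (df["tgt"].str.strip() == "")
)
bad_rows.update(df.index[empty])

# Over-long pairs overflow the position embeddings unless you truncate.
chunk = 50_000
for start in range(0, len(df), chunk):
    sub = df.iloc[start:start + chunk]
    for col in ("src", "tgt"):
        enc = tokenizer(sub[col].fillna("").tolist(), add_special_tokens=True)
        for i, ids in enumerate(enc["input_ids"]):
            if len(ids) > MAX_LEN or len(ids) <= 2:   # <= 2 means only [CLS]/[SEP]
                bad_rows.add(sub.index[i])

print(f"{len(bad_rows)} suspicious rows out of {len(df)}")
```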

I’m using pytorch-lightning with ddp-sharded and 16-bit precision for training in all the experiments.
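For reference, a minimal sketch of that Trainer setup, assuming a reasonably recent pytorch-lightning API (flag names vary a bit across versions, and ddp_sharded needs fairscale installed); the device count is made up:

```python
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,               # assumption: adjust to your hardware
    strategy="ddp_sharded",  # sharded DDP via fairscale
    precision=16,            # fp16 / native AMP: small values can underflow to 0
)
```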

16-bit precision has a tendency to flush small values to 0 more often than 32-bit does; a value that underflows to 0 can then produce an inf (e.g. through a division or a log), which leads to a downstream NaN. You can trigger a breakpoint (a sketch follows the list below) by:

  1. checking the output of the network (i.e. at the end of the forward pass) for inf or NaN, and stopping if you encounter it;
  2. checking the first layer’s gradients for NaN/inf (i.e. at the end of the backward pass), and stopping the code (before optimizer.step() is actually called and ends up corrupting your weights) if this happens.
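
A minimal sketch of both checks as pytorch-lightning hooks; it assumes the forward pass returns a single logits tensor, and compute_loss is a placeholder for your own loss code (the rest of the module is omitted):

```python
import torch
import pytorch_lightning as pl

class DebugSeq2Seq(pl.LightningModule):
    # forward(), configure_optimizers(), etc. omitted -- add the two hooks
    # below to your existing LightningModule.

    def training_step(self, batch, batch_idx):
        out = self(batch)                      # end of the forward pass
        loss = self.compute_loss(out, batch)   # hypothetical loss helper
        # 1. stop as soon as the forward output contains inf/NaN
        if not torch.isfinite(out).all():
            breakpoint()                       # inspect `batch` and `out` here
        return loss

    def on_after_backward(self):
        # 2. runs after loss.backward() but before optimizer.step(),
        #    so the weights are still uncorrupted at this point
        first_param = next(self.parameters())
        if first_param.grad is not None and not torch.isfinite(first_param.grad).all():
            breakpoint()                       # inspect gradients / the last batch here
```

One caveat: with native AMP (precision=16) the gradients seen in on_after_backward are still scaled, so an occasional inf there can just be the grad scaler backing off its scale; a persistent inf, or a NaN in the forward output itself, is the real signal.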

Then you can inspect the conditions under which the NaNs arise (was it in the forward or the backward pass? did it have something to do with the data or with float16?).
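
If the first non-finite value shows up in the backward pass, PyTorch’s anomaly detection can point at the exact operation that produced it (it slows training down considerably, so only enable it while debugging); recent pytorch-lightning versions also expose it as a Trainer flag:

```python
import torch

# Plain PyTorch: report the backward op whose output became NaN/inf.
torch.autograd.set_detect_anomaly(True)

# Or, if your pytorch-lightning version supports it, via the Trainer:
# trainer = pl.Trainer(detect_anomaly=True, precision=16, strategy="ddp_sharded")
```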