Exploding NLLLoss

I am using the Transformer module provided by PyTorch to train a text-generation model, with NLLLoss() measuring the quality of the reconstruction. After a certain number of iterations the loss explodes and all the weights become NaN. This is a log produced by the training program:

root - WARNING - Loss: 203.81146240234375
root - WARNING - Loss: 124.32596588134766
root - WARNING - Loss: 62.59440612792969
root - WARNING - Loss: 59.84109115600586
root - WARNING - Loss: 59.247005462646484
root - WARNING - Loss: 48.832725524902344
root - WARNING - Loss: 57.592288970947266
root - WARNING - Loss: 50.18443298339844
root - WARNING - Loss: 46.474849700927734
root - WARNING - Loss: 52.12908172607422
root - WARNING - Loss: 50.090736389160156
root - WARNING - Loss: 66.04253387451172
root - WARNING - Loss: 49.094024658203125
root - WARNING - Loss: 36.69044494628906
root - WARNING - Loss: 48.54591369628906
root - WARNING - Loss: 60.71137237548828
root - WARNING - Loss: 40.35478591918945
root - WARNING - Loss: 49.070556640625
root - WARNING - Loss: 54.33742141723633
root - WARNING - Loss: 47.14014434814453
root - WARNING - Loss: 55.043060302734375
root - WARNING - Loss: 47.63726043701172
root - WARNING - Loss: 46.314571380615234
root - WARNING - Loss: 41.330291748046875
root - WARNING - Loss: 48.85242462158203
root - WARNING - Loss: 50.59345245361328
root - WARNING - Loss: 48.508975982666016
root - WARNING - Loss: 43.35681915283203
root - WARNING - Loss: 45.875431060791016
root - WARNING - Loss: 51.701438903808594
root - WARNING - Loss: 39.1783561706543
root - WARNING - Loss: 30.14274024963379
root - WARNING - Loss: 44.33928680419922
root - WARNING - Loss: 40.88005447387695
root - WARNING - Loss: 62.682804107666016
root - WARNING - Loss: 45.18329620361328
root - WARNING - Loss: 39.7137451171875
root - WARNING - Loss: 47.31813049316406
root - WARNING - Loss: 50.755348205566406
root - WARNING - Loss: 40.52918243408203
root - WARNING - Loss: 49.48160934448242
root - WARNING - Loss: 58.29778289794922
root - WARNING - Loss: 45.660675048828125
root - WARNING - Loss: 55.13115692138672
root - WARNING - Loss: 50.72150421142578
root - WARNING - Loss: 33.377098083496094
root - WARNING - Loss: 48.404151916503906
root - WARNING - Loss: 60.24494934082031
root - WARNING - Loss: 46.290470123291016
root - WARNING - Loss: 9.493173539216099e+24

As you can see, the loss goes down for a while, as it should, and then suddenly spikes. I have tried gradient clipping to mitigate the issue, but it did not solve the problem.

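```python
import torch
import torch.nn as nn

criterion_1 = nn.NLLLoss()                          # default reduction="mean"
y_hat = model(X_train)
y_hat = y_hat.transpose(0, 1)                       # swap the first two dims to line up with tgt
mask = (tgt != pad_idx).bool()                      # True for non-<pad> target tokens
y_hat = nn.functional.log_softmax(y_hat, dim=-1)
cel = criterion_1(y_hat.reshape(-1, vocab_size), tgt.reshape(-1))
# NOTE: cel is already a scalar (mean over ALL tokens, <pad> included), so
# masked_select only repeats it once per non-<pad> token: loss == cel * mask.sum()
loss = cel.masked_select(mask.reshape(-1)).sum()
loss.backward()
torch.nn.utils.clip_grad_value_(model.parameters(), 100)   # clamps each gradient entry to [-100, 100]
optimizer.step()
```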

The code above is what I use to compute the loss.
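For reference, clip_grad_value_(model.parameters(), 100) only clamps individual gradient entries whose magnitude exceeds 100, so depending on the model it may rarely change anything. A norm-based clip rescales the whole gradient vector instead. The sketch below compares the two on a hypothetical toy nn.Linear standing in for the Transformer; the max_norm=1.0 threshold is an assumption, not a value from this thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 4)                    # hypothetical stand-in for the real model
x = torch.randn(16, 8)
loss = 1000.0 * model(x).pow(2).sum()      # deliberately large loss -> large gradients
loss.backward()

def grad_norm(params):
    # Global L2 norm over all parameter gradients
    return torch.sqrt(sum(p.grad.pow(2).sum() for p in params)).item()

before = grad_norm(model.parameters())

# Element-wise clipping: every gradient entry is clamped to [-100, 100]
torch.nn.utils.clip_grad_value_(model.parameters(), 100)
after_value = grad_norm(model.parameters())

# Norm-based clipping: rescale all gradients so the global L2 norm is <= 1.0.
# The return value is the total norm measured before this rescaling.
total = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
after_norm = grad_norm(model.parameters())
```

Value clipping leaves the direction of the update distorted and the overall norm potentially large; norm clipping preserves the direction and bounds the step size.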

Perhaps perfect predictors exist and training reaches a (1, 0, 0, …) state. y_hat = y_hat.clamp(-b, b) should solve that (with b around 10 to 20, applied before the softmax).
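The reasoning behind clamping before the softmax: if every logit lies in [-b, b], then logsumexp(z) <= b + log V for a vocabulary of size V, so every log-probability is at least -2b - log V and the per-token NLL is bounded. A minimal pure-Python sketch; the helper names and the logit values are hypothetical, and b = 15 is just a value from the suggested 10 to 20 range.

```python
import math

def log_softmax(logits):
    # Numerically stable: subtract the max before exponentiating
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def clamped_log_softmax(logits, b):
    # Clamp each logit to [-b, b] *before* the softmax
    return log_softmax([min(max(z, -b), b) for z in logits])

logits = [80.0, -60.0, 0.0, -5.0]    # hypothetical, very confident prediction
b = 15.0
raw = log_softmax(logits)
clamped = clamped_log_softmax(logits, b)
bound = -2 * b - math.log(len(logits))   # lower bound on any clamped log-prob

print(raw[1], clamped[1], bound)
```

Without the clamp, a wrong target at index 1 costs 140 nats; with it, the cost cannot exceed 2b + log V. Note that the clamp has zero gradient for logits outside [-b, b], so it bounds the loss but does not by itself stop the weights from drifting.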

For some reason, clamping the predictions causes the loss to increase after a certain point. This continues until some of the model weights become NaN.

Actually, I suggested clamping early (before the softmax), and that is tricky with log_softmax. Clamping after log_softmax to (-20., -1e-6), or an additional loss mask, may work instead. Or it is something else; I'd place a breakpoint and inspect the problematic network output.
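A minimal sketch of clamping after log_softmax with the (-20., -1e-6) range from the suggestion above, combined with a <pad> mask via ignore_index. The helper name clamped_nll and the example tensors are hypothetical; the only mask used here excludes <pad> tokens.

```python
import torch
import torch.nn.functional as F

def clamped_nll(logits, tgt, pad_idx, lo=-20.0, hi=-1e-6):
    """Mean NLL over non-<pad> tokens with log-probs clamped to [lo, hi].

    logits: (N, V) raw scores; tgt: (N,) target ids.
    `clamped_nll` is a hypothetical helper, not a PyTorch API.
    """
    log_probs = F.log_softmax(logits, dim=-1).clamp(min=lo, max=hi)
    # reduction="none" keeps one loss per token; ignore_index makes <pad> positions 0
    per_tok = F.nll_loss(log_probs, tgt, ignore_index=pad_idx, reduction="none")
    keep = tgt != pad_idx
    return per_tok.sum() / keep.sum().clamp(min=1)

pad_idx = 0
logits = torch.tensor([[100.0, -100.0, 0.0],   # very confident and wrong for target 1
                       [0.0, 0.0, 0.0]], requires_grad=True)
tgt = torch.tensor([1, pad_idx])               # second position is <pad>
loss = clamped_nll(logits, tgt, pad_idx)
loss.backward()
print(loss.item(), logits.grad)
```

The per-token loss is capped at 20 instead of 200, but the clamped entries receive zero gradient: in this example the whole gradient is zero, so the model stops learning from that token rather than being corrected by it.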

A few things to check before trying gradient clipping:

  1. What does your input data look like? Make sure it is in the form you would expect. Unnormalized input can sometimes cause huge loss values.
  2. What optimizer and learning rate are you using?
  3. I'm not sure it's a good idea to sum the loss values before loss.backward().

  1. The input data has shape (batch_size, max_len) and the output has shape (batch_size, max_len, vocab_size).
  2. I am using the Adam optimizer with a learning rate of 0.001.
  3. I tried training the model with the mean loss instead of the sum, but the loss still spikes.
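On point 3: summing makes both the loss and its gradients scale with the number of tokens in the batch. The sketch below, on hypothetical random data, checks that with no ignored targets the summed loss and its gradient are exactly N times the mean versions. Adam divides each update by a running RMS of the gradient, so it is largely insensitive to such a constant rescaling, which is consistent with switching to the mean not removing the spike; a fixed clip_grad_value_ threshold, on the other hand, is not scale-invariant.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, V = 12, 7                               # hypothetical token count and vocab size
logits = torch.randn(N, V, requires_grad=True)
tgt = torch.randint(0, V, (N,))
log_probs = torch.log_softmax(logits, dim=-1)

loss_sum = nn.NLLLoss(reduction="sum")(log_probs, tgt)
loss_mean = nn.NLLLoss(reduction="mean")(log_probs, tgt)

# retain_graph so the shared log_softmax graph survives for the second call
(grad_sum,) = torch.autograd.grad(loss_sum, logits, retain_graph=True)
(grad_mean,) = torch.autograd.grad(loss_mean, logits)
```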

I tried clamping the output after log_softmax; this is the resulting log:

root - WARNING - Loss: 7753.49169921875
root - WARNING - Loss: 6186.9287109375
root - WARNING - Loss: 5434.07861328125
root - WARNING - Loss: 6422.82568359375
root - WARNING - Loss: 6344.4873046875
root - WARNING - Loss: 5779.78515625
root - WARNING - Loss: 5681.9140625
root - WARNING - Loss: 5288.10498046875
root - WARNING - Loss: 5314.443359375
root - WARNING - Loss: 4506.3115234375
root - WARNING - Loss: 5896.3134765625
root - WARNING - Loss: 6842.0830078125
root - WARNING - Loss: 9111.4599609375
root - WARNING - Loss: 7685.61328125
root - WARNING - Loss: 8802.61328125
root - WARNING - Loss: 11280.5126953125
root - WARNING - Loss: 14238.529296875
root - WARNING - Loss: 13673.314453125
root - WARNING - Loss: 13150.68359375
root - WARNING - Loss: 13360.0
root - WARNING - Loss: 13180.0
root - WARNING - Loss: nan

The NLLLoss becomes NaN after a few batches.

A few things you can check:

  1. Ensure that the model output is what you would expect (i.e. the scale is the same as tgt).
  2. I'm not sure what this part is doing: loss = cel.masked_select(mask.reshape(-1)).sum()

Could the loss variable be removed and cel.backward() simply be called instead?

I usually use criterion = nn.CrossEntropyLoss() to avoid confusion.
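CrossEntropyLoss applies log_softmax internally, so on raw logits it gives the same value as NLLLoss on explicit log-probabilities, including the ignore_index handling. A small check on hypothetical random logits and hand-picked targets:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
pad_idx, vocab_size = 0, 6                         # hypothetical values
logits = torch.randn(10, vocab_size)               # raw scores, one row per token
tgt = torch.tensor([1, 2, 3, 0, 5, 4, 0, 1, 2, 3])  # 0 = <pad>

# CrossEntropyLoss = log_softmax + NLLLoss in one call
ce = nn.CrossEntropyLoss(ignore_index=pad_idx)(logits, tgt)
nll = nn.NLLLoss(ignore_index=pad_idx)(F.log_softmax(logits, dim=-1), tgt)
```

Both return a scalar that can be backpropagated directly, so either choice still computes the token log-likelihood.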

masked_select is there to remove the loss for the <pad> tokens. I'm building an architecture similar to a variational autoencoder, which uses the log-likelihood as the loss; that is why I used NLLLoss rather than CrossEntropyLoss.
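One thing worth checking in the original snippet: nn.NLLLoss() defaults to reduction="mean", so cel is already a single scalar averaged over all tokens, <pad> included. masked_select then broadcasts that scalar and the sum equals cel times the number of non-<pad> tokens, so the padding is never actually excluded. A sketch on hypothetical tensors comparing that pattern with two ways of masking <pad> properly:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
pad_idx, vocab_size = 0, 5                       # hypothetical values
log_probs = F.log_softmax(torch.randn(2, 4, vocab_size), dim=-1)  # (batch, max_len, vocab)
tgt = torch.tensor([[3, 1, 0, 0],
                    [2, 4, 1, 0]])               # 0 = <pad>
mask = tgt != pad_idx
flat_lp, flat_tgt, flat_mask = log_probs.reshape(-1, vocab_size), tgt.reshape(-1), mask.reshape(-1)

# Original pattern: mean over ALL tokens, then repeated once per real token
cel = nn.NLLLoss()(flat_lp, flat_tgt)
original = cel.masked_select(flat_mask).sum()

# Option 1: let NLLLoss skip <pad> targets (mean over real tokens only)
masked_mean = nn.NLLLoss(ignore_index=pad_idx)(flat_lp, flat_tgt)

# Option 2: per-token losses, then mask explicitly (sum over real tokens)
per_tok = nn.NLLLoss(reduction="none")(flat_lp, flat_tgt)
masked_sum = per_tok.masked_select(flat_mask).sum()
```

Option 2 keeps a summed log-likelihood, which may be closer to what a VAE-style objective expects; option 1 is the same quantity divided by the token count.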

Then it is probably something else. If you sample from trainable distributions, the issue may be there.
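As an illustration of where sampling can overflow: in a typical VAE reparameterization the encoder predicts a log-variance and exponentiates it, and in float32 exp() overflows to inf for inputs above roughly 88. The sketch below assumes that standard Gaussian setup; the function name and the clamp range [-10, 10] are hypothetical choices, not taken from the poster's model.

```python
import torch

def reparameterize(mu, logvar, min_logvar=-10.0, max_logvar=10.0):
    """Hypothetical Gaussian reparameterization with a clamped log-variance."""
    # Assumed heuristic: keep exp(logvar) in a finite range
    logvar = logvar.clamp(min_logvar, max_logvar)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std                         # differentiable sample
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
    return z, kl

mu = torch.zeros(3, 4)
logvar = torch.full((3, 4), 200.0)             # e.g. an encoder output that has blown up
z, kl = reparameterize(mu, logvar)
```

As with the other clamps in this thread, values outside the range get zero gradient, so this keeps the forward pass finite but does not explain why the encoder produced such values in the first place.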

Generally, torch.autograd.set_detect_anomaly(True) should point to the operation that produces the NaN. Since it slows training down, it is best to enable it only late, shortly before the failure.
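A minimal sketch of anomaly detection using the torch.autograd.detect_anomaly() context manager (the scoped counterpart of set_detect_anomaly), plus a hypothetical helper for the "breakpoint and inspect" step that lists parameters that have already gone non-finite. The sqrt of a negative number is just a toy way to produce a NaN.

```python
import torch
import torch.nn as nn

def nonfinite_params(model):
    """Hypothetical helper: names of parameters whose values or grads hold NaN/Inf."""
    bad = []
    for name, p in model.named_parameters():
        grad_bad = p.grad is not None and not torch.isfinite(p.grad).all()
        if grad_bad or not torch.isfinite(p).all():
            bad.append(name)
    return bad

# Anomaly mode: backward raises at the first function that returns NaN, and because
# the forward pass also ran in anomaly mode, it reports where that op was created.
x = torch.tensor([-1.0, 4.0], requires_grad=True)
caught = None
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x).sum()                    # sqrt(-1) is NaN in the forward pass
    try:
        y.backward()
    except RuntimeError as err:
        caught = str(err)

layer = nn.Linear(2, 2)
with torch.no_grad():
    layer.bias[0] = float("nan")               # simulate a weight that has already gone NaN
print(caught)
print(nonfinite_params(layer))
```

In a training loop, a check like nonfinite_params(model) after each optimizer.step() (or torch.isfinite on the network output) gives a natural place to stop and inspect the batch that triggered the blow-up.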