I am using PyTorch's Transformer module to train a text-generation model, with NLLLoss() to measure the quality of the reconstruction. After a certain number of iterations, the loss explodes and turns all the weights to NaN. Here is a log produced by the training program:
root - WARNING - Loss: 203.81146240234375
root - WARNING - Loss: 124.32596588134766
root - WARNING - Loss: 62.59440612792969
root - WARNING - Loss: 59.84109115600586
root - WARNING - Loss: 59.247005462646484
root - WARNING - Loss: 48.832725524902344
root - WARNING - Loss: 57.592288970947266
root - WARNING - Loss: 50.18443298339844
root - WARNING - Loss: 46.474849700927734
root - WARNING - Loss: 52.12908172607422
root - WARNING - Loss: 50.090736389160156
root - WARNING - Loss: 66.04253387451172
root - WARNING - Loss: 49.094024658203125
root - WARNING - Loss: 36.69044494628906
root - WARNING - Loss: 48.54591369628906
root - WARNING - Loss: 60.71137237548828
root - WARNING - Loss: 40.35478591918945
root - WARNING - Loss: 49.070556640625
root - WARNING - Loss: 54.33742141723633
root - WARNING - Loss: 47.14014434814453
root - WARNING - Loss: 55.043060302734375
root - WARNING - Loss: 47.63726043701172
root - WARNING - Loss: 46.314571380615234
root - WARNING - Loss: 41.330291748046875
root - WARNING - Loss: 48.85242462158203
root - WARNING - Loss: 50.59345245361328
root - WARNING - Loss: 48.508975982666016
root - WARNING - Loss: 43.35681915283203
root - WARNING - Loss: 45.875431060791016
root - WARNING - Loss: 51.701438903808594
root - WARNING - Loss: 39.1783561706543
root - WARNING - Loss: 30.14274024963379
root - WARNING - Loss: 44.33928680419922
root - WARNING - Loss: 40.88005447387695
root - WARNING - Loss: 62.682804107666016
root - WARNING - Loss: 45.18329620361328
root - WARNING - Loss: 39.7137451171875
root - WARNING - Loss: 47.31813049316406
root - WARNING - Loss: 50.755348205566406
root - WARNING - Loss: 40.52918243408203
root - WARNING - Loss: 49.48160934448242
root - WARNING - Loss: 58.29778289794922
root - WARNING - Loss: 45.660675048828125
root - WARNING - Loss: 55.13115692138672
root - WARNING - Loss: 50.72150421142578
root - WARNING - Loss: 33.377098083496094
root - WARNING - Loss: 48.404151916503906
root - WARNING - Loss: 60.24494934082031
root - WARNING - Loss: 46.290470123291016
root - WARNING - Loss: 9.493173539216099e+24
As you can see, the loss decreases for a while, as it should, and then suddenly spikes. I have tried gradient clipping to mitigate the issue, but it did not solve the problem.
criterion_1 = nn.NLLLoss(reduction='none')  # per-token losses, so the pad mask can be applied
optimizer.zero_grad()
y_hat = model(X_train)
y_hat = y_hat.transpose(0, 1)
y_hat = nn.functional.log_softmax(y_hat, dim=-1)
mask = (tgt != pad_idx)  # != already yields a bool tensor
cel = criterion_1(y_hat.reshape(-1, vocab_size), tgt.reshape(-1))
loss = cel.masked_select(mask.reshape(-1)).sum()
loss.backward()
torch.nn.utils.clip_grad_value_(model.parameters(), 100)
optimizer.step()
The above is the code I use to compute the loss.
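For comparison, summing the per-token losses makes the loss (and gradient) scale grow with batch size and sequence length, which may contribute to the blow-up. A common alternative is to let NLLLoss skip the pad tokens via `ignore_index` and average over the rest, combined with norm-based clipping. This is only a minimal sketch with dummy tensors standing in for the model output and targets, not the actual model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shapes; in the real code these come from the model and batch.
vocab_size, pad_idx = 10, 0
batch, seq_len = 4, 7

# Stand-in for model(X_train) flattened to (tokens, vocab) and the targets.
logits = torch.randn(batch * seq_len, vocab_size, requires_grad=True)
tgt = torch.randint(0, vocab_size, (batch * seq_len,))

# ignore_index drops pad positions from the loss; reduction='mean' averages
# over the remaining tokens, so the loss scale is independent of batch size.
criterion = nn.NLLLoss(ignore_index=pad_idx, reduction='mean')
log_probs = nn.functional.log_softmax(logits, dim=-1)
loss = criterion(log_probs, tgt)
loss.backward()

# Clipping the total gradient norm (rather than per-element values) is a
# common guard against the occasional exploding batch.
torch.nn.utils.clip_grad_norm_([logits], max_norm=1.0)
```

With mean reduction the loss stays in a per-token range, which also makes the log easier to read across different batch sizes.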