I have code that currently seems to run really well, but it has one problem: sometimes during training it returns NaN.
Here is a snippet from my latest logfile, showing the mean loss of each epoch:
0 Loss(train): 5.21e-01 Loss(val): 3.66e-01
1 Loss(train): 2.61e-01 Loss(val): 2.09e-01
2 Loss(train): 1.91e-01 Loss(val): 1.85e-01
3 Loss(train): 1.56e-01 Loss(val): 1.54e-01
4 Loss(train): 1.33e-01 Loss(val): 1.15e-01
5 Loss(train): 1.19e-01 Loss(val): 1.13e-01
6 Loss(train): 1.07e-01 Loss(val): 1.11e-01
7 Loss(train): 9.75e-02 Loss(val): 9.25e-02
8 Loss(train): 9.14e-02 Loss(val): 1.00e-01
9 Loss(train): 8.57e-02 Loss(val): 8.18e-02
10 Loss(train): 7.88e-02 Loss(val): 7.34e-02
11 Loss(train): 7.45e-02 Loss(val): 7.97e-02
12 Loss(train): 7.18e-02 Loss(val): 6.37e-02
13 Loss(train): 6.75e-02 Loss(val): 6.64e-02
14 Loss(train): 6.79e-02 Loss(val): 6.55e-02
15 Loss(train): 2.12e-01 Loss(val): 6.11e-01
16 Loss(train): 3.95e-01 Loss(val): 7.29e-02
17 Loss(train): 6.69e-02 Loss(val): 5.58e-02
18 Loss(train): 5.95e-02 Loss(val): 5.19e-02
19 Loss(train): nan Loss(val): nan
20 Loss(train): nan Loss(val): nan
21 Loss(train): nan Loss(val): nan
22 Loss(train): nan Loss(val): nan
As far as I can tell the network trains decently and seems to be on the right track, and then suddenly the loss goes to NaN. Lowering the learning rate seems to prevent it, but I'm hesitant to do that since the learning rate otherwise seems well set for the problem.
I have tried to prevent it so far by using:
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0e-2, norm_type=2.0)
optimizer.step()
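One thing I've realized I could do with this same call is log the gradient norm, since clip_grad_norm_ returns the total norm computed *before* clipping; that would at least tell me whether the gradients are actually exploding when the loss spikes. A minimal sketch (the tiny Linear model and random batch here are just stand-ins to make it runnable):

```python
import torch
import torch.nn as nn

# Stand-in model, optimizer, and batch; in my real code these are my own.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# clip_grad_norm_ returns the total norm of the gradients *before*
# clipping, so logging it reveals whether a gradient explosion
# precedes the NaN epochs in the logfile above
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0e-2, norm_type=2.0)
print(f"grad norm before clipping: {total_norm:.3e}")

optimizer.step()
```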
but that doesn't seem to work. I have also tried to use:
torch.autograd.set_detect_anomaly(True)
but no problems show up in the first couple of iterations (I haven't let it run until it hits the problem yet, since detect_anomaly slows down the code quite a bit, and in the past any errors have been caught in the initial iterations with detect_anomaly).
My next attempt will be to run the code until the error occurs with detect_anomaly enabled, but that will take quite a while. Are there any other standard approaches I should try in order to figure out why my loss suddenly goes to NaN in the middle of training?
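In the meantime, one cheaper check I'm considering is asserting finiteness of the loss and gradients every step, so training stops on the exact batch that first blows up instead of only surfacing as a NaN epoch mean. A sketch of what I have in mind (check_finite is just a helper name I made up; the small model and batch are stand-ins):

```python
import torch
import torch.nn as nn

def check_finite(model, loss, batch_idx):
    # Raise on the first non-finite loss or gradient, so the offending
    # batch index is known immediately (debugging sketch, not my real code)
    if not torch.isfinite(loss):
        raise RuntimeError(f"non-finite loss at batch {batch_idx}")
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            raise RuntimeError(f"non-finite grad in {name} at batch {batch_idx}")

# usage: call right after loss.backward() inside the training loop
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
check_finite(model, loss, batch_idx=0)  # passes silently on a healthy step
```

This should be much faster than detect_anomaly, at the cost of only telling me *when* things go bad rather than which backward op produced the NaN.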