Train loss and test loss increasing after about 600 epochs


I am running a multichannel UNET on MRI images of different sequences.
After about 600 epochs, I get these losses (blue = train loss, orange = test loss):

Why do the losses increase (a lot) after 600 epochs? What should I do?
I also think my model overfits, as the gap between the train loss and the test loss suggests. How can I reduce this gap?

Some parameters:

  • loss function = dice loss
  • optimizer = Adam
  • batch size = 10
  • learning rate = 1e-4

Hi Nestlee!

It looks like your model / training is becoming unstable, which is something
that can happen. Note that the Adam optimizer can be a little finicky and
unstable, even though it often trains faster when it’s working.

First, try using Adam's weight_decay parameter (if you aren’t already).
Depending on the cause of your instability, this might very well help.
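A minimal sketch of turning weight decay on, assuming a PyTorch setup (the question doesn't show any code). The `nn.Conv2d` layer is a hypothetical stand-in for the UNET, and `weight_decay=1e-5` is an assumed starting value to tune, not a number from this thread.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the multichannel UNET (4 input sequences, 2 classes).
model = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=3, padding=1)

# Same lr as in the question; weight_decay adds an L2 penalty on the weights.
# 1e-5 is an assumed value: try a few orders of magnitude (e.g. 1e-6 to 1e-3).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
```

PyTorch also provides `torch.optim.AdamW`, which decouples the weight decay from Adam's adaptive update and takes the same arguments, so it is an easy alternative to try.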

Your use of a UNET suggests that you are performing semantic segmentation,
which is a per-pixel classification problem.

I would recommend not using pure dice loss. I would suggest starting with
pure CrossEntropyLoss – the go-to loss criterion for classification – with
class weights (the weight argument), if you have class imbalances. If you
want to use dice loss, I would suggest augmenting CrossEntropyLoss
with dice loss.
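Here is one way the combination could look, as a minimal sketch that assumes raw logits of shape (N, C, H, W) and integer label maps of shape (N, H, W), which is the layout `nn.CrossEntropyLoss` expects for segmentation. The `soft_dice_loss` helper, the `CEPlusDiceLoss` wrapper, its `dice_weight` parameter and the example class weights are all hypothetical; dice loss has several variants and this is just one of them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_dice_loss(logits, target, eps=1e-6):
    """Hypothetical soft dice loss, averaged over classes.

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer class labels.
    """
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    # One-hot encode the labels to (N, C, H, W) so they line up with probs.
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
    dims = (0, 2, 3)  # sum over batch and pixels, keep one score per class
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (cardinality + eps)
    return 1 - dice.mean()


class CEPlusDiceLoss(nn.Module):
    """Class-weighted CrossEntropyLoss plus a dice term scaled by dice_weight."""

    def __init__(self, class_weights=None, dice_weight=1.0):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(weight=class_weights)
        self.dice_weight = dice_weight

    def forward(self, logits, target):
        return self.ce(logits, target) + self.dice_weight * soft_dice_loss(logits, target)


# Example: 2 images, 3 classes, 8x8 pixels; background (class 0) down-weighted
# with an assumed weight of 0.2.
torch.manual_seed(0)
logits = torch.randn(2, 3, 8, 8)
target = torch.randint(0, 3, (2, 8, 8))
criterion = CEPlusDiceLoss(class_weights=torch.tensor([0.2, 1.0, 1.0]), dice_weight=1.0)
loss = criterion(logits, target)
```

Setting `dice_weight=0.0` gives pure weighted cross-entropy, which matches the suggested starting point; the dice term can then be added back gradually.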

As mentioned above, consider using weight_decay.

If you still have stability problems, consider switching to the SGD
optimizer, most likely with momentum and probably with weight_decay.
SGD tends to be better behaved, even if it doesn’t train as fast.

Note that the learning rate (the lr argument) doesn’t map directly
from Adam to SGD. You will likely want to use a significantly higher
value for lr when using SGD.
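A minimal sketch of the SGD setup, again assuming PyTorch, with a hypothetical stand-in model and dummy data. The values `lr=1e-2` (100 times the Adam learning rate from the question), `momentum=0.9` and `weight_decay=1e-5` are assumed starting points to tune, not values from this thread.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the UNET and a dummy batch of 10 four-sequence slices.
model = nn.Conv2d(4, 2, kernel_size=3, padding=1)
images = torch.randn(10, 4, 32, 32)
labels = torch.randint(0, 2, (10, 32, 32))

# lr is 100x the Adam lr from the question; all three values are assumptions to tune.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()

# One ordinary training step: the loop itself is the same as with Adam.
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()
```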

(Unlikely to be the cause of your problem.)

Consider reducing the learning rate significantly (perhaps by one or two,
or more, orders of magnitude) before the instability sets in, perhaps
around epoch 500.
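In PyTorch, one way to schedule such a drop is `torch.optim.lr_scheduler.MultiStepLR`. This is a minimal sketch assuming one `scheduler.step()` call per epoch; the milestone of 500 and `gamma=0.1` (one order of magnitude) are assumptions taken from the rough numbers above, and the model and data are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

model = nn.Conv2d(4, 2, kernel_size=3, padding=1)  # hypothetical stand-in for the UNET
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Multiply lr by gamma=0.1 once 500 epochs have completed (assumed milestone).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500], gamma=0.1)

x = torch.randn(2, 4, 8, 8)            # dummy batch standing in for one epoch of data
y = torch.randint(0, 2, (2, 8, 8))
criterion = nn.CrossEntropyLoss()

lrs = []
for epoch in range(600):
    lrs.append(optimizer.param_groups[0]["lr"])  # lr used during this epoch
    optimizer.zero_grad()
    criterion(model(x), y).backward()
    optimizer.step()
    scheduler.step()  # once per epoch, after the optimizer steps
```

For a drop of two orders of magnitude, use `gamma=0.01`, or add a second milestone.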

It’s not clear that your model is overfitting, per se. Both your train and test
losses go down nicely before they both plateau. It may be that you are
doing as well as you can, given your specific model and the data you have.

However, sometimes training does plateau for quite a while, but then
starts getting better again as you keep training. So, if you can avoid the
instability, you might be able to train (much) longer, and get past your
(long-lasting, but temporary) plateau.

As a general rule, to do better on your test set, it helps to train with more
training data. Of course, you might not have more data to train on, but
you can often get significant benefit by using data augmentation to “derive”
some additional partially “fake” data from your “real” training data.
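For segmentation, the key detail is that each random transform must be applied identically to the image and its label mask. Here is a minimal sketch, assuming each sample is a (C, H, W) tensor with one channel per MRI sequence and a square (H, W) integer mask. The `augment_pair` function is hypothetical, and whether flips and rotations are anatomically reasonable depends on your data.

```python
import torch


def augment_pair(image, mask, generator=None):
    """Hypothetical helper: flip/rotate an MRI slice and its mask identically.

    image: (C, H, W) tensor, one channel per MRI sequence.
    mask:  (H, W) tensor of integer class labels (H == W assumed for rotations).
    """
    if torch.rand(1, generator=generator).item() < 0.5:
        image = torch.flip(image, dims=(2,))  # horizontal flip
        mask = torch.flip(mask, dims=(1,))
    if torch.rand(1, generator=generator).item() < 0.5:
        image = torch.flip(image, dims=(1,))  # vertical flip
        mask = torch.flip(mask, dims=(0,))
    k = int(torch.randint(0, 4, (1,), generator=generator).item())  # 0 to 3 quarter turns
    image = torch.rot90(image, k, dims=(1, 2))
    mask = torch.rot90(mask, k, dims=(0, 1))
    return image, mask


# Example: a 4-sequence 16x16 slice and its 3-class mask.
image = torch.randn(4, 16, 16)
mask = torch.randint(0, 3, (16, 16))
aug_image, aug_mask = augment_pair(image, mask)
```

A natural place to call something like this is a `Dataset.__getitem__`, applied to training samples only, so each epoch sees a slightly different version of the data.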

Good luck.

K. Frank