Weird Training Loss Curve

Hi, I am training a simple LSTM for a regression task and got the following training loss curve.

While training, I logged the MSE computed on the whole dataset at each epoch. The weird thing is that there is a flat region in the training curve. It seems that my network got stuck and learned nothing from epoch 2 to epoch 10. Then it had an ‘aha moment’ at around epoch 11 and the loss dropped rapidly.

This really puzzled me. Any idea why this happened?

For code, please check here.
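
In case the link goes stale, here is a minimal sketch of the kind of setup I described (the model size, dummy data, and hyperparameters below are simplified placeholders, not my actual code):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class SimpleLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=32):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)          # out: (batch, seq_len, hidden)
        return self.head(out[:, -1])   # regress from the last time step

# Dummy data: sequences of length 20 with scalar targets.
X = torch.randn(512, 20, 1)
y = torch.randn(512, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

model = SimpleLSTM()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(30):
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
    # Log the MSE over the whole dataset at the end of each epoch.
    with torch.no_grad():
        epoch_mse = criterion(model(X), y).item()
    print(f"epoch {epoch}: MSE = {epoch_mse:.4f}")
```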

Your model might have run into a saddle point, where the gradient is very small.
After epoch 10 it could escape the saddle point and gain some speed.

This is a good illustration:

Taken from: http://www.offconvex.org/2016/03/22/saddlepoints/
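
As a toy illustration of the effect (my own sketch, not related to your model): plain gradient descent on the saddle `f(x, y) = x**2 - y**2`, started very close to the saddle point at the origin, sits on a plateau for many steps before the loss suddenly drops once the unstable direction takes over:

```python
# Gradient descent on f(x, y) = x**2 - y**2, a textbook saddle.
lr = 0.05
x, y = 1.0, 1e-6   # y starts almost exactly on the saddle
for step in range(200):
    f = x**2 - y**2
    gx, gy = 2 * x, -2 * y       # gradient of f
    x, y = x - lr * gx, y - lr * gy
    if step % 20 == 0:
        print(f"step {step:3d}: f = {f:+.6f}")
# f hovers near 0 for dozens of steps (tiny gradient), then plunges
# rapidly as y grows along the descending direction.
```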


Yeah, this makes sense. But I repeated the training process several times and kept getting almost exactly the same ‘flat region’. Since the dataset is randomly shuffled by the DataLoader, it shouldn’t run into the saddle point every time, right?
Or is this phenomenon caused by some intrinsic property of the model / dataset used?

From my experience, this is actually pretty typical behaviour for an LSTM training curve. My intuition for this phenomenon is as follows:

  1. the LSTM learns to predict the “average” signal by setting the bias parameters (first plateau)
  2. the LSTM figures out how to use the (forward and recurrent) weight parameters to actually solve the task (second plateau)

This can be interpreted as a saddle point (where the loss surface is much steeper along the bias dimensions than along the other dimensions).
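
A quick way to sanity-check point 1 on your own data (a sketch with stand-in targets, not your actual dataset): a bias-only model can at best predict the target mean, so the first plateau should sit at roughly the variance of the targets.

```python
import torch

y = torch.randn(1000) * 3.0 + 5.0          # stand-in targets
const_pred = y.mean()                       # what a bias-only model learns
plateau_mse = ((const_pred - y) ** 2).mean()
# MSE of the mean prediction equals the target variance:
print(plateau_mse.item(), y.var(unbiased=False).item())
```

If the flat region in your curve sits at about the variance of your targets, that strongly suggests the network spent epochs 2 to 10 predicting the mean before the weights started doing real work.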