Loss spikes late in the training process

My model is given by:

import torch
from torch import nn

class MyNetQ(nn.Module):
    def __init__(self):
        super().__init__()
        self.myp = 0.5  # dropout probability shared by all Dropout layers
        self.flatten = nn.Flatten()
        self.linear_LeakyReLU_stack = nn.Sequential(
            nn.Linear(8*68, 8*60),
            nn.LeakyReLU(),
            nn.Dropout(p=self.myp),
            nn.Linear(8*60, 8*50),
            nn.LeakyReLU(),
            nn.Dropout(p=self.myp),
            nn.Linear(8*50, 8*40),
            nn.LeakyReLU(),
            nn.Dropout(p=self.myp),
            nn.Linear(8*40, 8*30),
            nn.LeakyReLU(),
            nn.Dropout(p=self.myp),
            nn.Linear(8*30, 8*20),
            nn.LeakyReLU(),
            nn.Dropout(p=self.myp),
            nn.Linear(8*20, 8*15),
            nn.LeakyReLU(),
            nn.Linear(8*15, 11)
        )
    def forward(self, x):
        x = self.flatten(x)
        out = self.linear_LeakyReLU_stack(x)
        return out

During training, the soft cross-entropy loss exhibits huge spikes after seemingly having made good progress:

What could be the reason for this behavior? Is there any way I could make the training process more stable?
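
For context, the loss is computed along these lines, assuming "soft" cross entropy means cross entropy against per-class probability targets. This is a simplified sketch, not my exact training code; the dummy batch is just a placeholder:

import torch
from torch import nn

model = MyNetQ()
# nn.CrossEntropyLoss accepts per-class probability targets directly (PyTorch >= 1.10).
criterion = nn.CrossEntropyLoss()

# Placeholder batch: 4 samples with 8*68 input features, soft targets over 11 classes.
batch_x = torch.randn(4, 8, 68)
soft_targets = torch.softmax(torch.randn(4, 11), dim=1)  # each row sums to 1

logits = model(batch_x)                 # shape (4, 11)
loss = criterion(logits, soft_targets)  # soft cross entropy

# Equivalent manual form:
# loss = -(soft_targets * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()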

A hint, anyone? I suspect there may be a flaw in how I set up the model, but I don’t see it.

I think your model is overfitting. What does your validation loss look like?

The blue line in the plot is the average training loss per epoch, and the orange line is the average validation loss per epoch. Dropout is active during training and disabled during validation. Actually, since the training and validation losses stay so close together at all times, I don’t think overfitting is an issue. There must be something else going on?
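
Roughly, each pass looks like this. It’s a simplified sketch of the loop, not the exact code; the loaders and criterion come from my setup and are not shown here:

import torch

def run_epoch(model, loader, criterion, optimizer=None):
    # Training pass if an optimizer is given, otherwise a validation pass.
    training = optimizer is not None
    if training:
        model.train()   # dropout layers active
    else:
        model.eval()    # dropout layers disabled
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for batch_x, soft_targets in loader:
            logits = model(batch_x)
            loss = criterion(logits, soft_targets)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    return total_loss / len(loader)  # average loss over the epoch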

My bad, I didn’t see the legend. Did you experiment with learning rate decay? Maybe after some epochs your learning rate is too high to settle into a good minimum.

I am using optimizer = torch.optim.Adam(mymodel.parameters(), lr=0.003). Good point, I’ll experiment with different learning rates. Thanks for the suggestion.
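
For the learning-rate experiments I’ll start from something like this. It’s a sketch: num_epochs and train_one_epoch stand in for my existing loop, and the decay factor is just a first guess to tune:

import torch

optimizer = torch.optim.Adam(mymodel.parameters(), lr=0.003)
# Decay the learning rate by 5% after every epoch (0.95 is an arbitrary starting point).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

for epoch in range(num_epochs):
    train_one_epoch(mymodel, optimizer)  # placeholder for the existing training code
    scheduler.step()                     # apply the decay once per epoch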


Use a learning rate scheduler. It should help with convergence.
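
For example, ReduceLROnPlateau lowers the learning rate whenever the validation loss stops improving, so the decay is tied to your validation curve instead of a fixed schedule. A sketch; optimizer, num_epochs, and the train/validate helpers are placeholders for your existing code:

import torch

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

for epoch in range(num_epochs):
    train_one_epoch(mymodel, optimizer)  # placeholder: existing training code
    val_loss = validate(mymodel)         # placeholder: existing validation code
    scheduler.step(val_loss)             # halve the lr after 3 epochs without improvement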
