As the title goes.
My network classifies all samples as the last class it was trained on. It doesn't matter whether it trained for 12 hours or 12 minutes, on 100000 samples or 10 samples, etc.
It only ever predicts the last class it was trained on: if the last training loop used class 0, every classification becomes 0.
Is there a known fix for this issue? How do I find out what causes it?
How are your training and validation accuracies during training?
Do you set your model to evaluation mode using model.eval() before testing it?
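A small sketch of why the train/eval switch matters when the network has dropout. The model here is a hypothetical stand-in, not the poster's network: in train mode dropout randomly zeroes activations, while in eval mode it is disabled, and wrapping evaluation in torch.no_grad() additionally skips building the autograd graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model with a dropout layer (not the poster's network).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 3))
x = torch.randn(2, 4)

model.train()            # dropout active: each forward pass draws a new random mask
a, b = model(x), model(x)

model.eval()             # dropout disabled: forward passes are deterministic
with torch.no_grad():    # no autograd graph is built during evaluation
    c, d = model(x), model(x)

print(torch.equal(c, d))  # True
print(c.requires_grad)    # False
```

In practice this means calling model.eval() before every validation or test pass and model.train() again before resuming training.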
Thanks for the reply!
I set model.train() during training and model.eval() during testing (I have a dropout layer).
I don't calculate any validation accuracy during training. I'll look into doing that. What insight will the validation accuracy give me?
If your model behaves strangely during training, the training accuracy should be high, since your model seems to “learn” the last classes, while the validation accuracy should be really low.
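A minimal sketch of the two numbers worth tracking, assuming the model outputs one row of class scores (logits) per sample. The helper names accuracy and prediction_histogram are hypothetical. Running them on both the training and the validation set (under model.eval() and torch.no_grad()) shows the gap described above, and the histogram makes a collapse onto a single class obvious.

```python
import torch

def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of samples whose arg-max class matches the target label."""
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean().item()

def prediction_histogram(logits: torch.Tensor, num_classes: int) -> list:
    """Number of samples assigned to each class; a collapse shows up as one non-zero bin."""
    preds = logits.argmax(dim=1)
    return torch.bincount(preds, minlength=num_classes).tolist()

# Toy example: every sample is predicted as class 0.
logits = torch.tensor([[2.0, 0.1, 0.3], [1.5, 0.2, 0.1], [3.0, 1.0, 0.0], [0.9, 0.5, 0.4]])
targets = torch.tensor([0, 1, 2, 0])
print(accuracy(logits, targets))        # 0.5
print(prediction_histogram(logits, 3))  # [4, 0, 0]
```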
I’ve finally tested it and you’re totally right. I get a very high training accuracy (89%) and a low validation accuracy (38%).
I think my problem is that I want the network to give me a classification every 5 time steps, which means the same class is classified many times in a row for long time series. I can see that the first 10-11 classification attempts are wrong, but then the model learns the correct class and classifies the remaining part of the time series correctly.
However, it resets and starts over when a new time series is used. I calculate the loss and step the optimizer each time the network gives a classification. Since it gives many classifications for each time series, the model quickly learns that single time series but forgets everything else.
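To see why long series produce long runs of the same label, here is a sketch of cutting one labelled series into windows of 5 time steps. It assumes non-overlapping windows and that leftover samples at the end are dropped; the poster's actual windowing may differ, and make_windows is a hypothetical helper.

```python
from typing import List, Sequence, Tuple

WINDOW = 5  # one classification per 5 time steps

def make_windows(series: Sequence[float], label: int, window: int = WINDOW) -> List[Tuple[List[float], int]]:
    """Split one labelled time series into non-overlapping windows that all carry the series' label."""
    return [
        (list(series[start:start + window]), label)
        for start in range(0, len(series) - window + 1, window)
    ]

windows = make_windows(list(range(12)), label=2)
print(len(windows))  # 2 (the trailing 2 samples are dropped)
print(windows[0])    # ([0, 1, 2, 3, 4], 2)
```

A series of length N yields N // 5 windows with identical labels, so training on them in order means N // 5 consecutive updates toward the same class.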
loss = criterion(output, y)
optimizer.zero_grad()  # reset gradients
loss.backward()
optimizer.step()
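One common remedy for this kind of forgetting (not something suggested earlier in the thread) is to pool the windows from all series and shuffle them, so each mini-batch mixes classes instead of replaying one series back-to-back. The sketch below uses made-up data shapes (3 series, 50 steps, 1 feature). Note the caveat: if the model carries a hidden state from one window to the next within a series, shuffling breaks that dependency and a different scheme (e.g. batching whole series) would be needed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical data: 3 series (one per class), 50 time steps, 1 feature each,
# cut into non-overlapping windows of 5 steps -> 10 windows per series.
WINDOW, STEPS, NUM_CLASSES = 5, 50, 3
series = torch.randn(NUM_CLASSES, STEPS, 1)
labels_per_series = torch.arange(NUM_CLASSES)

windows = series.reshape(NUM_CLASSES, STEPS // WINDOW, WINDOW, 1)  # (series, window, time, feature)
x = windows.reshape(-1, WINDOW, 1)                                  # (all windows, time, feature)
y = labels_per_series.repeat_interleave(STEPS // WINDOW)            # each window inherits its series' label

# shuffle=True draws windows from all series in random order, so a mini-batch
# usually contains several classes instead of one series' windows in sequence.
loader = DataLoader(TensorDataset(x, y), batch_size=8, shuffle=True)

xb, yb = next(iter(loader))
print(xb.shape)     # torch.Size([8, 5, 1])
print(yb.tolist())  # label order depends on the shuffle
```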
Is there anything to combat this? Would it make sense to simply call model(), compute the predicted class from its output, and not calculate the loss or step the optimizer?
output = model(nbatch, batch_size)
And then maybe calculate the loss and step the optimizer at longer intervals (e.g. every 100 or 1000 classifications)?
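If the goal is fewer optimizer updates, a standard way to do it without discarding the intermediate predictions is gradient accumulation: compute the loss for every classification, call backward() each time (gradients add up in .grad), and step only every N classifications. The sketch below uses a hypothetical model, random data, and ACCUM_STEPS = 100; dividing each loss by ACCUM_STEPS makes the accumulated gradient an average. Accumulation alone does not help much if all N windows come from one series (they all share a class), so it pairs naturally with mixing series. The last lines show pure inference with no loss and no update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the poster's model, data and hyperparameters.
model = nn.Sequential(nn.Flatten(), nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ACCUM_STEPS = 100  # step the optimizer once every 100 classifications

windows = torch.randn(1000, 5, 1)      # 1000 windows of 5 time steps, 1 feature
labels = torch.randint(0, 3, (1000,))  # one class label per window

model.train()
optimizer.zero_grad()
num_updates = 0
for i, (x, y) in enumerate(zip(windows, labels), start=1):
    output = model(x.unsqueeze(0))                       # one classification, shape (1, 3)
    loss = criterion(output, y.unsqueeze(0)) / ACCUM_STEPS
    loss.backward()                                      # gradients accumulate across iterations
    if i % ACCUM_STEPS == 0:
        optimizer.step()                                 # one update per ACCUM_STEPS classifications
        optimizer.zero_grad()
        num_updates += 1

print(num_updates)  # 10

# Pure inference (no loss, no update): eval mode plus no_grad.
model.eval()
with torch.no_grad():
    preds = model(windows[:4]).argmax(dim=1)
print(preds.shape)  # torch.Size([4])
```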