I am solving a multiclass classification problem using a pretrained resnet101, and I am running into a problem with how the loss behaves during training and validation. I start by training only the fully connected layer and then gradually ‘unfreeze’ the other layers. While I do this, the validation loss steadily decreases.
But when I ‘unfreeze’ the whole network, the losses on both training and validation start to increase after each epoch.
What could be a possible reason for that?
It could be that the pretrained weights are far enough from what your problem needs that, once everything is trainable, the gradients become very large and training suffers from exploding gradients. That said, I don’t think this is usually a big problem with resnets, thanks to batch norm.
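To see whether the gradients really do blow up, you could log their overall norm after each backward pass. Below is a minimal sketch in PyTorch. The tiny `nn.Sequential` model and the random batch are stand-ins I made up so the snippet runs on its own, and `global_grad_norm` is a hypothetical helper, not something from a library. In your case you would call it on your resnet101 right after `loss.backward()`.

```python
import torch
import torch.nn as nn


def global_grad_norm(model: nn.Module) -> float:
    """L2 norm over all parameter gradients currently stored on the model."""
    squared = 0.0
    for p in model.parameters():
        # Frozen parameters (requires_grad=False) have no gradient and are skipped.
        if p.grad is not None:
            squared += p.grad.detach().pow(2).sum().item()
    return squared ** 0.5


# Stand-in model and random data (assumptions, only here to keep this runnable).
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
inputs = torch.randn(4, 8)
targets = torch.randint(0, 3, (4,))

loss = nn.functional.cross_entropy(model(inputs), targets)
loss.backward()

grad_norm = global_grad_norm(model)
print(f"loss={loss.item():.4f} grad_norm={grad_norm:.4f}")
```

If this number jumps by orders of magnitude right after you unfreeze everything, that would support the exploding-gradient explanation. `torch.nn.utils.clip_grad_norm_` returns the same total norm and can also cap it, if you want to try clipping.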
One way to combat this is to have a burn-in period of some sort. You can do what you did and unfreeze the network layer by layer. Another way is to start with a really low learning rate and increase it for a while until the network has adapted, and then continue with your normal learning rate schedule.
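The low-then-increasing learning rate idea could look roughly like this. It is only a sketch: `warmup_factor`, the linear ramp, `start_factor=0.01`, the 5 warm-up steps and the base rate of `1e-3` are all values I picked for illustration, so tune them for your setup.

```python
def warmup_factor(step: int, warmup_steps: int, start_factor: float = 0.01) -> float:
    """Multiplier for the base learning rate during a linear warm-up.

    Ramps linearly from start_factor at step 0 up to 1.0 at warmup_steps,
    then stays at 1.0 so the normal schedule can take over.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return 1.0
    return start_factor + (1.0 - start_factor) * step / warmup_steps


# Illustrative values only: base learning rate and warm-up length are assumptions.
base_lr = 1e-3
lrs = [base_lr * warmup_factor(step, warmup_steps=5) for step in range(7)]
for step, lr in enumerate(lrs):
    print(f"step {step}: lr={lr:.6f}")
```

In PyTorch you can pass a function like this as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's initial learning rate by the returned factor each time you call `scheduler.step()`.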
Thank you for your advice, I will check the gradients.