I am solving a multiclass classification problem using a pretrained resnet101. I am running into a problem with the training and validation behavior. I start by training only the fully-connected layer and then gradually 'unfreeze' all the other layers. The validation loss steadily decreases during this phase.
But when I 'unfreeze' the whole network, both the training and validation losses start increasing with each epoch.
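For reference, the freeze-then-unfreeze setup looks roughly like this (a minimal sketch assuming torchvision's `resnet101`; the number of classes and the unfreezing order are illustrative):

```python
import torch
import torchvision

num_classes = 10  # placeholder for the actual number of classes
model = torchvision.models.resnet101(pretrained=True)

# Phase 1: freeze everything, then swap in a new trainable classifier head
for p in model.parameters():
    p.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# Phase 2: gradually unfreeze deeper blocks (e.g. layer4 first, then layer3, ...)
for p in model.layer4.parameters():
    p.requires_grad = True

# Final phase: unfreeze the whole network
for p in model.parameters():
    p.requires_grad = True
```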
It could be that the newly unfrozen layers are so far out of tune with your problem that the gradients become very large and you effectively suffer from gradient explosion. With resnets this is usually less of an issue thanks to batch norm, but it can still happen.
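If you want to check for or guard against this, gradient clipping is a simple safeguard (a sketch, not part of the suggestion above; `model`, `loss`, `optimizer`, and the `max_norm=1.0` threshold all come from your training loop or are illustrative):

```python
import torch

loss.backward()
# Cap the global gradient norm before the optimizer step;
# clip_grad_norm_ returns the norm *before* clipping, so logging it
# lets you see whether gradients are actually spiking
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"grad norm before clipping: {total_norm:.2f}")
optimizer.step()
```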
One way to combat this is to have a burn-in period of sorts. You can do what you did and unfreeze the network layer by layer. Another way is to start with a very low learning rate and increase it over a warm-up phase until the network has adapted, then continue with your normal learning rate schedule.
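A minimal sketch of the second option using PyTorch's built-in schedulers (the warm-up length, learning rates, and `train_one_epoch` are placeholders for your own setup):

```python
import torch

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Warm up: ramp the LR from 1% of its target to 100% over 5 epochs,
# then hand off to the normal schedule (step decay here as an example)
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=5)
main = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, main], milestones=[5]
)

num_epochs = 50  # placeholder
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)  # your training loop
    scheduler.step()
```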