How can I add an additional output head without losing performance on the original head?

Hi,

I am trying to create a resnet18 model with two different output heads. One head predicts the image class (classification_head) and the other predicts the rotation of the image (rotation_head).

I am using a pre-trained resnet18.

I have frozen all the model parameters using

for param in model_multi.parameters():
	param.requires_grad = False

and I only unfreeze the rotation classification head

for param in model_multi.rotation_head.parameters():
	param.requires_grad = True

Then I train the model to fine-tune the rotation head. The problem is that when I switch back to the classification head, the classification accuracy drops (from 0.85 to 0.79). In theory there should be no drop at all, since the parameters used to classify the image class are unchanged.

To validate this, I even created a deep copy of the model before fine-tuning the rotation_head and compared the fine-tuned model against it. When I calculate the difference between the parameters of the fine-tuned model and those of the copy (i.e. not fine-tuned), only the rotation head's values differ; all other model parameters are unchanged.
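A sketch of how that before/after comparison can be extended to cover buffers as well as parameters, by comparing every state_dict entry. To keep it fast it uses a tiny hypothetical Conv2d + BatchNorm2d model as a stand-in for resnet18. With all parameters frozen, a single forward pass in training mode is already enough to change the BatchNorm buffers.

```python
import copy

import torch
import torch.nn as nn


def changed_entries(model_a: nn.Module, model_b: nn.Module) -> list[str]:
    """Names of state_dict entries (parameters AND buffers) that differ."""
    state_a, state_b = model_a.state_dict(), model_b.state_dict()
    return [name for name in state_a if not torch.equal(state_a[name], state_b[name])]


torch.manual_seed(0)
# Hypothetical tiny stand-in for the frozen backbone.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
for p in model.parameters():
    p.requires_grad = False
snapshot = copy.deepcopy(model)

model.train()
with torch.no_grad():
    model(torch.randn(4, 3, 16, 16))  # no backward pass, no optimizer step

param_names = {n for n, _ in model.named_parameters()}
changed = changed_entries(snapshot, model)
print([n for n in changed if n in param_names])      # []
print([n for n in changed if n not in param_names])  # ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```

Comparing only model.parameters() would report no change here, even though the layer's output in eval mode is now different.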

At this stage, I have no idea why the accuracy of the classification head is changing, unless something other than the model parameters affects classification and needs to be frozen while training the rotation head. What could this other thing be? I would appreciate any help or advice.

Thanks

I found the reason why the accuracy changed even though all model parameters were frozen. The running mean and variance in the batch normalisation layers (the running_mean and running_var buffers) were still being updated. These are buffers, not parameters, so setting requires_grad to False does not freeze them, and they did not show up when I compared the model parameters. In training mode, every forward pass updates them. To stop the batch normalisation layers from changing, one can use the code below after calling model.train().

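import torch.nn as nn

model.train()  # must come first: train() puts BatchNorm layers back in training mode
for module in model.modules():
	if isinstance(module, nn.BatchNorm2d):
		module.eval()  # stop running_mean / running_var from updating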
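A slightly more general version of the snippet above, as a sketch. The helper name train_with_frozen_bn is hypothetical, and it checks for the 1d, 2d and 3d BatchNorm classes (resnet18 itself only uses BatchNorm2d). Because model.train() resets every submodule to training mode, the helper has to be used in place of model.train() every time, e.g. at the start of each epoch.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def train_with_frozen_bn(model: nn.Module) -> nn.Module:
    """Training mode for the model, but eval mode for every BatchNorm layer.

    Call this instead of model.train(): a plain model.train() would switch
    the BatchNorm layers back to updating their running statistics.
    """
    model.train()
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()  # normalise with, and keep, the stored running stats
    return model


torch.manual_seed(0)
# Hypothetical tiny stand-in model with one BatchNorm layer.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
bn = model[1]
before = bn.running_mean.clone()

train_with_frozen_bn(model)
model(torch.randn(4, 3, 16, 16))
print(model.training, bn.training)           # True False
print(torch.equal(before, bn.running_mean))  # True
```

A side effect worth knowing: in eval mode, BatchNorm normalises with the stored running statistics instead of batch statistics, so the rotation head is trained on the same features the classification head sees at test time.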

Alternatively, one can save the batch normalisation running statistics before fine-tuning the rotation head and reload them afterwards.
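A sketch of that save-and-reload alternative, assuming the statistics are captured with named_buffers() and restored with load_state_dict(..., strict=False). The helper names save_buffers and restore_buffers are hypothetical. The sketch copies all buffers; in torchvision's resnet18 the only buffers are the BatchNorm running_mean, running_var and num_batches_tracked tensors.

```python
import torch
import torch.nn as nn


def save_buffers(model: nn.Module) -> dict[str, torch.Tensor]:
    """Snapshot every buffer (for BatchNorm: running_mean, running_var, num_batches_tracked)."""
    return {name: buf.detach().clone() for name, buf in model.named_buffers()}


def restore_buffers(model: nn.Module, saved: dict[str, torch.Tensor]) -> None:
    # strict=False because `saved` holds buffers only, so every parameter is "missing".
    result = model.load_state_dict(saved, strict=False)
    if result.unexpected_keys:
        raise KeyError(f"unexpected keys: {result.unexpected_keys}")


torch.manual_seed(0)
# Hypothetical tiny stand-in model with one BatchNorm layer.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
saved = save_buffers(model)

model.train()
model(torch.randn(4, 3, 16, 16))  # running statistics drift here
restore_buffers(model, saved)

print(torch.equal(model[1].running_mean, saved["1.running_mean"]))  # True
```

This also makes it possible to keep one set of statistics per head and swap them in before evaluating each head, at the cost of remembering to do so.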

Hope this helps anyone else as well :slight_smile:
