Just to be clear on feature extraction

Just wanted to be clear on the PyTorch tutorial for feature extraction using a convnet: Transfer Learning for Computer Vision Tutorial — PyTorch Tutorials 1.12.0+cu102 documentation

Am I right to assume that in this set-up, batch norm layers are still active (depending on the model)?

If so, is there a benefit to this?

I have read elsewhere that batch norm layers should also be frozen, and should therefore use the ImageNet global statistics rather than being updated when doing feature extraction. If this holds true, why does the tutorial leave the batch norm layers active?
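
For reference, here is roughly what I mean by the set-up, following the tutorial's "ConvNet as fixed feature extractor" section (a sketch, not the tutorial code verbatim; the weights argument assumes a recent torchvision):

```python
import torch.nn as nn
from torchvision import models

# Roughly the feature-extraction set-up from the tutorial (sketch, not verbatim).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in model.parameters():
    param.requires_grad = False          # backbone gradients are frozen

model.fc = nn.Linear(model.fc.in_features, 2)  # new head, trainable by default

model.train()  # training mode: BatchNorm layers use batch statistics and keep
               # updating running_mean / running_var, even though their affine
               # parameters receive no gradients (requires_grad is False).

# Quick check: every BatchNorm2d module reports training=True here.
print(all(m.training for m in model.modules() if isinstance(m, nn.BatchNorm2d)))
```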

Yes, I think you are right: model.train() is used during training/fine-tuning, so the batch norm layers will be in training mode.
I guess the author saw benefits in the final accuracy of the fine-tuned model, but you could compare both approaches and check if this holds true.
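
If you do want to run that comparison, one common way to get the frozen-stats variant (a sketch, not something from the tutorial) is to switch the BatchNorm modules back to eval mode after calling model.train():

```python
import torch.nn as nn
from torchvision import models

def freeze_batchnorm_stats(model):
    """Switch all BatchNorm layers to eval mode so they keep using their stored
    running statistics instead of updating them from the current batch."""
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Variant A: as in the tutorial -- BN layers stay in training mode.
model.train()

# Variant B: frozen BN statistics. Note that model.train() re-enables the
# updates, so this has to be re-applied after every call to model.train(),
# e.g. at the start of each epoch.
model.train()
freeze_batchnorm_stats(model)
```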

Could you post the reference you saw, as it might also explain why these stats are not updated (maybe the model becomes more prone to overfitting)?


Yes, sure. Here is one paper I came across:

The authors state the following:

Within the context of fine-tuning a model, the batch normalisation (batch norm) (Ioffe and Szegedy, 2015) layer requires taking a few precautions due to how it operates differently between training and inference. During training, the layer uses the current batch mean and standard deviation to normalise the activations, and, at the same time, it updates exponentially moving averages of the mean and standard deviation and stores them as non-trainable weights to use during inference. While fine-tuning a model, it is usually recommended to use the batch norm layer in inference mode to avoid unexpected or poor performance on the validation and test sets from additional updates during training.
