How do BatchNorm layers behave when we use a pre-trained model to extract features?


(jasonhan) #1

Thank you for your time.
I’m trying to use a pre-trained resnet18 to extract features from my own images for a classification task.
At first, I did exactly what the official transfer learning tutorial does and used the model as a feature extractor, which works fine but is quite slow in my case (a pretty large amount of data and poor hardware).
I thought it would save me a lot of time to calculate the features (I chose the input of the final fully connected layer as the feature) just once and reuse them in every epoch.
So I tried to break the whole procedure into two parts:
1. Extract the features (the fully connected layer’s input) once.
2. Use those features as the model input from then on.
However, it doesn’t work well.
I think the reason lies in the difference between training mode and testing mode.
When I use the approach in the tutorial (computing the features every epoch), the training mode needs to be turned on and off between training time and testing time:

model.train(True)
or
model.train(False)

In resnet, I think this tells the batch normalization layers either to record the mean and variance of each activation or to apply the recorded values on the test set.
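This behavior can be checked directly on a single BatchNorm layer (a minimal sketch with random data; the shapes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)  # running_mean starts at 0, running_var at 1

# In training mode, BN normalizes with the current batch's statistics
# and updates its running estimates (by momentum, default 0.1).
bn.train()
x = torch.randn(8, 3, 4, 4) + 5.0  # batch with per-channel mean ~5
_ = bn(x)
print(bn.running_mean)  # moved from 0 toward the batch mean

# In eval mode, BN normalizes with the stored running estimates instead,
# and the running stats stay fixed.
bn.eval()
before = bn.running_mean.clone()
_ = bn(torch.randn(8, 3, 4, 4))
assert torch.equal(bn.running_mean, before)
```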

But when I use the second approach, I don’t know whether I should turn off the training mode or not.
I’ve tried turning off the training mode for both my training set and testing set, which amounts to not training the model but treating all my data as a test set for feature extraction. But when I use the features extracted this way, something goes wrong (training looks fine, but testing accuracy bounces up and down).
This is the code I use to build the extractor from the pre-trained model:

new_classifier = nn.Sequential(*list(pre_trained_model.children())[:-1])
for param in new_classifier.parameters():
    param.requires_grad = False
new_classifier.train(False)
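Concretely, the precompute step I’m attempting looks like this (a self-contained sketch: the tiny stand-in network and random tensors are placeholders for the real resnet18 backbone and my own data):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Stand-in extractor; in practice this would be
# nn.Sequential(*list(pre_trained_model.children())[:-1]) from resnet18.
extractor = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8),
                          nn.ReLU(), nn.AdaptiveAvgPool2d(1))
extractor.train(False)  # eval mode: BN uses its stored running stats
for p in extractor.parameters():
    p.requires_grad = False

# Placeholder dataset of 20 random "images" with binary labels.
loader = DataLoader(TensorDataset(torch.randn(20, 3, 16, 16),
                                  torch.randint(0, 2, (20,))), batch_size=8)

feats, labels = [], []
with torch.no_grad():  # no autograd graph needed for frozen extraction
    for x, y in loader:
        feats.append(extractor(x).flatten(1))  # (N, 8) feature vectors
        labels.append(y)

feats, labels = torch.cat(feats), torch.cat(labels)
# Save once, then reuse every epoch instead of re-running the backbone:
# torch.save({'features': feats, 'labels': labels}, 'features.pt')
print(feats.shape)  # torch.Size([20, 8])
```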

I don’t know if there is a bug in my code or if this is the wrong approach altogether.
I really want to know: when we use a model with batch normalization layers as a feature extractor, how are the mean and variance of each activation calculated? Do they simply reuse the values from the data the model was trained on, like ImageNet?
Do I need to turn the training mode on and off between my training and testing sets when I extract features?
It seems that few people have discussed this before.
Sorry for my poor English, I hope you can understand me.
Any help is appreciated, thanks ahead!!


(Simon Wang) #2

I think what you did should be basically fine. When a pretrained model is used as a feature extractor, you should always use the moving mean/var, because the model is considered fixed. You should also make sure that your input data is normalized to a similar input distribution to the one the pretrained model was trained with.
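For example, the torchvision pretrained models expect inputs normalized with the ImageNet channel statistics (mean/std values below are the ones documented for torchvision). A minimal sketch with plain tensors, using random images in [0, 1] as stand-ins for real ones:

```python
import torch

# ImageNet channel statistics used by torchvision's pretrained models.
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def normalize(batch):
    """batch: float images in [0, 1], shape (N, 3, H, W)."""
    return (batch - mean) / std

imgs = torch.rand(4, 3, 224, 224)  # stand-in for real images in [0, 1]
x = normalize(imgs)
# x now carries the same per-channel shift/scale the pretrained weights
# (and the BN running stats) saw during ImageNet training.
print(x.shape)
```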


(jasonhan) #3

Thank you for your answer.


(jasonhan) #4

Here is my understanding right now.
In the official tutorial
we turn on the training mode in the training phase:

    for phase in ['train', 'val']:
        if phase == 'train':
            scheduler.step()
            model.train(True)  # Set model to training mode
        else:
            model.train(False)  # Set model to evaluate mode

This means that even if we only want to train the final layer and freeze every layer before it, we still need to let the BatchNorm layers do their work during training and calculate the mean/var for our own training set.
Then, when the model is trained, we turn off the training mode and test it, applying the mean/var from the training set to the test data.
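One detail worth noting here: freezing the parameters with requires_grad = False does not freeze the running statistics; in training mode BN still updates them on every forward pass. A small check (stand-in network and random data, not the actual resnet18):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4))
for p in net.parameters():
    p.requires_grad = False  # weights frozen...

bn = net[1]
before = bn.running_mean.clone()

net.train(True)
_ = net(torch.randn(8, 3, 8, 8))
# ...but in training mode BN still updated its running statistics:
assert not torch.equal(bn.running_mean, before)

net.train(False)
frozen = bn.running_mean.clone()
_ = net(torch.randn(8, 3, 8, 8))
# In eval mode the running stats stay fixed:
assert torch.equal(bn.running_mean, frozen)
```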

So, if we choose to calculate the fully connected layer’s input (or whichever layer’s features we extract) once and save it for reuse in every epoch, the right thing to do should be to turn on the training mode and calculate the mean/var for all our data (train and test). In this case, the values of the extracted features should be equal to the tutorial approach (easier to understand, but a little computationally inefficient).
Otherwise, if we turn off the training mode to extract features, I don’t know where the mean/var for the BatchNorm layers would come from.