Cannot freeze batch normalization parameters

During training, I am making some of the layers of my model non-trainable via:

for param in model.parameters():
    param.requires_grad = False

However, after checking the parameters, I see that a lot of them still train and change, such as:


After searching a lot, I noticed these are batch norm parameters.
How can I freeze them, or in other words make them requires_grad=False?

Following is a toy example that shows requires_grad = False doesn't work correctly:

import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # index 1 is the BatchNorm layer (convs.1.running_mean etc. in the keys below)
        self.convs = nn.ModuleList([nn.Conv2d(3, 6, 3),
                                    nn.BatchNorm2d(6),
                                    nn.Conv2d(6, 10, 3),
                                    nn.Conv2d(10, 10, 3)])
        # parameter-free layers at indices 1 and 3 (ReLU assumed here)
        self.fcs = nn.Sequential(nn.Linear(320, 10),
                                 nn.ReLU(),
                                 nn.Linear(10, 5),
                                 nn.ReLU(),
                                 nn.Linear(5, 1))

    def forward(self, x):
        x = self.convs[0](x)
        x = self.convs[1](x)
        x = self.convs[2](x)
        x = self.convs[3](x)
        x = x.view(-1)  # flattens the whole batch: 2 * 10 * 4 * 4 = 320
        x = self.fcs(x)
        return x

model = Net()
print(model.state_dict().keys())

loss = nn.L1Loss()
target = torch.ones(1)

# Only convs.0.bias and fcs.2.weight stay trainable
for name, param in model.named_parameters():
    if name == 'convs.0.bias' or name == 'fcs.2.weight':
        param.requires_grad = True
    else:
        param.requires_grad = False

old_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

for epoch in range(5):
    X = torch.rand(2, 3, 10, 10)
    optimizer.zero_grad()
    out = model(X)
    output = loss(out, target)
    output.backward()
    optimizer.step()

new_state_dict = {k: v.clone() for k, v in model.state_dict().items()}

# Compare params
count = 0
for key in old_state_dict:
    if not torch.equal(old_state_dict[key], new_state_dict[key]):
        print('Diff in {}'.format(key))
        count += 1

Output:

dict_keys(['convs.0.weight', 'convs.0.bias', 'convs.1.weight', 'convs.1.bias', 'convs.1.running_mean', 'convs.1.running_var', 'convs.2.weight', 'convs.2.bias', 'convs.3.weight', 'convs.3.bias', 'fcs.0.weight', 'fcs.0.bias', 'fcs.2.weight', 'fcs.2.bias', 'fcs.4.weight', 'fcs.4.bias'])

Diff in convs.1.running_mean
Diff in convs.1.running_var
Diff in fcs.2.weight

I was dealing with that the whole day, and I think I finally got it. Adding this will make BN not trainable:

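def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm2d') != -1:
        m.eval()

# call after model.train(), which puts the BN layers back into training mode
model.train()
model.apply(set_bn_eval)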

In the default settings, nn.BatchNorm layers have trainable affine parameters (gamma and beta in the original paper, weight and bias in PyTorch) as well as running estimates.
If you don’t want to use the batch statistics and update the running estimates, but instead want to use the running stats, you should call m.eval() as shown in your example.
However, this won’t disable the gradients for weight and bias!
If you don’t want to train them at all, you can just specify affine=False. Otherwise you should treat them as trainable parameters.
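A minimal sketch of that point, using a single toy nn.BatchNorm2d layer and random input (both illustrative, not from the thread): after eval() the running estimates stay fixed, yet weight and bias still receive gradients, so an optimizer holding them would keep updating them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)   # affine=True and track_running_stats=True by default
bn.eval()                # normalize with the running estimates and stop updating them

mean_before = bn.running_mean.clone()
out = bn(torch.randn(4, 3, 8, 8))
out.sum().backward()

print(torch.equal(bn.running_mean, mean_before))             # True: running stats untouched
print(bn.weight.grad is not None, bn.bias.grad is not None)  # True True: gamma/beta still get gradients
```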


Are you saying that if I’m using a pretrained model, want to train my model starting from its pretrained weights, and set affine = False, the BN parameters will be treated as not trainable, keep their pretrained values, and won’t be updated?

No, sorry for the misunderstanding.
affine is only taken into account when the model is instantiated.
If the nn.BatchNorm layers were already created using affine=True, both parameters will be in the model, and you should treat them as other parameters, i.e. setting requires_grad=False if you don’t want to train them further.
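As a sketch of this advice, assuming a hypothetical helper named freeze_bn and a small stand-in network in place of a real pretrained model, BN layers can be frozen completely by combining both mechanisms: eval() for the running estimates and requires_grad=False for weight and bias. The check at the end compares the BN state_dict before and after a few optimizer steps.

```python
import torch
import torch.nn as nn
import torch.optim as optim

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def freeze_bn(model: nn.Module) -> None:
    """Hypothetical helper: freeze BN affine params and running stats."""
    for m in model.modules():
        if isinstance(m, BN_TYPES):
            m.eval()                  # normalize with, and stop updating, running stats
            for p in m.parameters():  # weight (gamma) and bias (beta) when affine=True
                p.requires_grad = False

torch.manual_seed(0)
# Stand-in for a pretrained network (assumption: any model containing BN layers)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
                      nn.Flatten(), nn.Linear(8 * 8 * 8, 1))
model.train()     # model.train() would switch BN back to training mode ...
freeze_bn(model)  # ... so freeze after calling it
bn = model[1]
snapshot = {k: v.clone() for k, v in bn.state_dict().items()}

optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(2, 3, 10, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()

unchanged = all(torch.equal(snapshot[k], v) for k, v in bn.state_dict().items())
print(unchanged)  # True
```

Passing only parameters with requires_grad=True to the optimizer is optional here, but it mirrors the filter(...) call in the original toy example.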


Got it! The model I had was trained with affine=True, and that’s why it has running_mean and running_var. Apparently requires_grad=False won’t make them non-trainable (or at least it didn’t work for me), so I had to put them in eval mode.
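One pitfall with the eval-mode approach: calling model.train() (e.g. at the start of every epoch) flips the BN layers back into training mode, so set_bn_eval has to be re-applied afterwards. A sketch of an alternative, using a hypothetical FrozenBNNet model, overrides train() so the BN layers always stay in eval mode. This only freezes the running stats; weight and bias additionally need requires_grad=False if they should not be trained.

```python
import torch
import torch.nn as nn

class FrozenBNNet(nn.Module):
    """Hypothetical example model whose BatchNorm layers always stay in eval mode."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(3, 6, 3), nn.BatchNorm2d(6), nn.ReLU())

    def forward(self, x):
        return self.body(x)

    def train(self, mode: bool = True):
        super().train(mode)  # sets `training` on this module and every child
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()     # keep using (and not updating) the running stats
        return self

model = FrozenBNNet()
model.train()  # e.g. at the start of every epoch
print(model.training, model.body[1].training)  # True False
```

Since model.eval() internally calls train(False), the override also behaves correctly when switching to evaluation.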


So if the nn.BatchNorm layers were created with affine=False, their weight and bias (beta and gamma) will not be updated because they are not trainable parameters, right? Is affine=False equivalent to requires_grad=False or torch.no_grad()?

And model.eval() makes the nn.BatchNorm layers stop updating the running estimates.
If I am wrong, please correct me.

Thanks in advance.

If affine was set to False, these parameters are set to None in the nn.BatchNorm implementation, i.e. they don’t exist at all.

Yes, model.eval() will not update the running stats and instead apply them.
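To make the options concrete, here is a small sketch with toy layers (illustrative only): affine=False removes weight and bias entirely, whereas requires_grad=False keeps them but excludes them from gradient updates, and torch.no_grad() only disables gradient recording (running stats would still be updated in train mode). The second half shows the running mean moving by momentum * batch_mean in train mode (default momentum=0.1, starting from zero) and staying fixed in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn_plain = nn.BatchNorm2d(3, affine=False)
print(bn_plain.weight, bn_plain.bias)    # None None
print(len(list(bn_plain.parameters())))  # 0: nothing for an optimizer to update

bn = nn.BatchNorm2d(3)                   # momentum=0.1 by default
x = torch.randn(8, 3, 4, 4) + 5.0

bn.train()
bn(x)                                    # train mode: running stats are updated
frozen = bn.running_mean.clone()
print(torch.allclose(frozen, 0.1 * x.mean(dim=(0, 2, 3))))  # True

bn.eval()
bn(x)                                    # eval mode: running stats are applied, not updated
print(torch.equal(bn.running_mean, frozen))  # True
```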


Is it a huge problem if these are calculated during inference (whilst keeping all other layers frozen for feature extraction)?

If you are updating the running stats during inference, you could “leak” this inference dataset, which might be bad.
I.e. if this inference dataset is the validation or test dataset, you would then use it to “train” the model by updating the running stats of the batchnorm layers. If you call model.eval() and use these running stats afterwards on the validation/test dataset, I would see it as a data leak and would claim the model performance is biased in its favor.
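The leak mechanism can be checked directly. This sketch (toy model and a random stand-in for the test data; the helper names running_stats and stats_changed are hypothetical) snapshots the BN buffers and shows that a forward pass in train mode changes them even inside torch.no_grad(), while eval() leaves them untouched.

```python
import torch
import torch.nn as nn

def running_stats(model):
    """Hypothetical helper: snapshot all BN running_mean/running_var buffers."""
    return {name: buf.clone() for name, buf in model.named_buffers()
            if "running_" in name}

def stats_changed(before, after):
    return any(not torch.equal(before[k], after[k]) for k in before)

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())
test_batch = torch.randn(4, 3, 8, 8)  # stand-in for validation/test data

before = running_stats(model)
with torch.no_grad():  # no_grad does NOT stop running-stat updates
    model.train()
    model(test_batch)
leaked = stats_changed(before, running_stats(model))

before = running_stats(model)
with torch.no_grad():
    model.eval()       # eval() uses the stored stats and leaves them alone
    model(test_batch)
clean = not stats_changed(before, running_stats(model))
print(leaked, clean)  # True True
```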


Say you extract features only in a separate inference run, directly from a pretrained ImageNet model.

For example, you extract features using the ImageNet-pretrained model, compute the mean and standard deviation from the test-set mini-batches, and feed the resulting feature vector into an already trained classifier layer that has not “seen” the test data.

I know this is odd, but would it be considered a data leak? I’m not sure whether it is a data leak or rather a way for the data to normalise itself within the ImageNet feature-extraction layers.

Perhaps this is wasted computation for a dataset that is not “seen” at all…

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# Define feature extractor
class DenseNetConv(torch.nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        original_model = models.densenet161(pretrained=True)
        self.features = torch.nn.Sequential(*list(original_model.children())[:-1])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        x = F.avg_pool2d(x, kernel_size=7).view(x.size(0), -1)
        return x

densenet = DenseNetConv()

classifier = nn.Linear(2208, args.num_classes)
bestmodel = get_best_model(best)
bestmodel = torch.load(bestmodel)

# only the classifier dense layer is set to eval()
classifier.eval()

with torch.no_grad():
    for i, (inputs10x, labels) in enumerate(dataloaders_dict['test']):
        # inputs10x = ...   (right-hand side omitted in the original post)
        # test_labels = ... (right-hand side omitted in the original post)
        # extract features
        x10 = densenet(inputs10x)
        # feed extracted features into previously trained classifier
        # Forward pass to get output/logits
        outputs = classifier(x10)
        _, pred = torch.max(outputs, 1)

Yes, I think it’s a data leak, as you are changing the model based on information from the test data.
Even if you don’t change model values directly (here you would be updating the running stats with the test data), you can still leak information from the test set, e.g. by selecting the best performing model on it.

I would see one exception: if you never plan to use these running stats, i.e. if you never call .eval() on the batch norm layers and are thus ignoring these stats, you might not be too concerned about this leak.


Ah ok many thanks for the detailed reply.

Yes, as seen in the code above, eval() is never called on the densenet, either in training or at inference time. The only model set to eval() is the “classifier”.

However, it came to my attention because the results differ from those I get when I set the densenet to eval() during both training and inference (I am using the densenet for feature extraction; it’s a strange way, I know, but there are plenty of ways of doing transfer learning from what I have seen!).

For example, the probabilities of the predicted values are not as extreme (close to 0 or close to 1) when the densenet is set to eval() as when it is left in train() for both training and inference. Do you have a hypothesis on why this might happen?

If the outputs of the training and validation runs differ a lot, I would assume that the normalization layers didn’t learn the underlying data distribution, which could be caused by e.g.:

  • a small training batch size, which creates noisy stats
  • a difference in the data distribution between the training and validation datasets

You might want to play around with the momentum argument of the batchnorm layers to try to smooth the stats updates.
If you don’t want to call eval() on the norm layers and always want to use the batch stats to normalize the activations, note that this normalization approach would depend on the actual batch size. This would also mean that your model might predict differently during inference depending on the used batch size, which is usually not wanted.
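A short sketch of both points with a toy nn.BatchNorm1d layer (values are illustrative). PyTorch’s momentum weights the new batch statistic, so a smaller momentum gives smoother running stats; this differs from the momentum used in optimizer classes. The second part shows that in train mode the output for the same sample depends on the batch it is normalized with.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.01  # default is 0.1; smaller values smooth the running stats more
bn = nn.BatchNorm1d(2, momentum=momentum)
x = torch.randn(16, 2) * 3 + 1

old_mean = bn.running_mean.clone()
bn.train()
bn(x)
# PyTorch update rule: running = (1 - momentum) * running + momentum * batch_stat
expected = (1 - momentum) * old_mean + momentum * x.mean(dim=0)
update_ok = torch.allclose(bn.running_mean, expected)
print(update_ok)  # True

# In train mode a sample is normalized with its batch's statistics,
# so the same input gives different outputs for different batch sizes.
with torch.no_grad():
    out_small = bn(x[:4])[0]
    out_large = bn(x)[0]
print(torch.allclose(out_small, out_large))  # False
```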


Yes, thank you, for this approach I do make sure that the batch size remains the same to ensure the same behaviour is applied at inference!

The only thing I still wonder about is why I get different results when I apply .eval() to the densenet feature extractor for both training and inference. In that case I essentially use the running stats predetermined by ImageNet, since the batch norm layers are also frozen this way. The predictions are generally more confident this way, too. I would have to apply an ROC curve (or similar) to get similar results from this method (densenet.eval() in training and inference) and the method described above (densenet.train() in training and inference)…

Perhaps, as you said, it is due to the small training batch size or the momentum argument…

I don’t fully understand this claim as you’ve previously mentioned that eval() is never called so the running stats would be updated during the entire training. If you call eval() afterwards these running stats will now be used to normalize the input activations.
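A minimal sketch of why switching to eval() changes the outputs, using a freshly created toy layer whose running stats (mean 0, variance 1) stand in for stored stats that don’t match the data: in eval mode the input is normalized with the stored stats, while in train mode each batch is normalized with its own mean and variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)               # fresh layer: running_mean = 0, running_var = 1
x = torch.randn(8, 3, 4, 4) * 4 + 2  # data whose stats differ from the stored ones

bn.eval()
with torch.no_grad():
    out_eval = bn(x)                 # normalized with the stored running stats
bn.train()
with torch.no_grad():
    out_train = bn(x)                # normalized with this batch's own mean/var

print(torch.allclose(out_train, out_eval))  # False
```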


Apologies for not making it clearer.

In the setup above, eval() is only called on the classifier, which is a linear model with one layer (hence classifier.eval()).

The input to this model is the features extracted from the densenet, which consequently still has batch norm running. However, as specified above, eval() is never applied directly to this densenet.

Basically, the densenet model as described above is a feature extractor where eval() is never called for this.

Alternatively, in a separate experiment, I called eval() on the densenet prior to feature extraction.

For example, in this separate setup for training:

densenet = DenseNetConv()
densenet.eval()  # apply eval() here too, to freeze the batch norm layers

classifier = nn.Linear(2208, args.num_classes)

for epoch in range(1, args.num_epochs):
    for i, (inputs10x, labels) in enumerate(dataloaders_dict['train']):
        # inputs10x = ...  (right-hand side omitted in the original post)
        # labels = ...     (right-hand side omitted in the original post)

        x10 = densenet(inputs10x)  # extract features

        # feed extracted features into classifier
        # Forward pass to get output/logits
        outputs = classifier(x10)

And thus for inference (test):

densenet = DenseNetConv()
densenet.eval()  # now we set densenet to eval(), as in training

classifier = nn.Linear(2208, args.num_classes)
bestmodel = get_best_model(best)
bestmodel = torch.load(bestmodel)
classifier.eval()  # the classifier is set to eval() as well

with torch.no_grad():
    for i, (inputs10x, labels) in enumerate(dataloaders_dict['test']):
        # inputs10x = ...   (right-hand side omitted in the original post)
        # test_labels = ... (right-hand side omitted in the original post)
        # extract features
        x10 = densenet(inputs10x)
        # feed extracted features into previously trained classifier
        # Forward pass to get output/logits
        outputs = classifier(x10)
        _, pred = torch.max(outputs, 1)

I am referring to the differences in results between the setup described here and the earlier one, where I don’t apply eval() to the densenet feature extractor.