Cannot freeze batch normalization parameters

While training my model I am making some of the layers not trainable via:

for param in model.parameters():
    param.requires_grad = False

However, after checking the parameters I see there are a lot of parameters that still train and change, such as:


extras.0.conv.7.running_var
extras.1.conv.1.running_mean
extras.1.conv.1.running_var
extras.1.conv.4.running_mean
extras.1.conv.4.running_var
extras.1.conv.7.running_mean
extras.1.conv.7.running_var
extras.2.conv.1.running_mean
extras.2.conv.1.running_var
extras.2.conv.4.running_mean
extras.2.conv.4.running_var
extras.2.conv.7.running_mean
extras.2.conv.7.running_var
extras.3.conv.1.running_mean
extras.3.conv.1.running_var
extras.3.conv.4.running_mean
extras.3.conv.4.running_var
extras.3.conv.7.running_mean
extras.3.conv.7.running_var

After searching a lot I noticed these are batch norm parameters.
How can I freeze them, or in other words make them requires_grad = False?
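
For reference, here is a minimal standalone check (a fresh nn.BatchNorm2d, not a layer from my model): in PyTorch, running_mean and running_var are registered as buffers rather than parameters, so they never show up in model.parameters() and requires_grad has no effect on them; they are updated by the forward pass whenever the module is in training mode.

import torch.nn as nn

bn = nn.BatchNorm2d(6)

# parameters: only the affine weight and bias
print([name for name, _ in bn.named_parameters()])

# buffers: running_mean and running_var live here, outside of parameters()
print([name for name, _ in bn.named_buffers()])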

The following is a toy example that shows requires_grad = False won't work correctly:


import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.convs = nn.ModuleList([nn.Conv2d(3, 6, 3),
                                    nn.BatchNorm2d(6),
                                    nn.Conv2d(6, 10, 3),
                                    nn.Conv2d(10, 10, 3)])

        self.fcs = nn.Sequential(nn.Linear(320, 10),
                                 nn.ReLU(),
                                 nn.Linear(10, 5),
                                 nn.ReLU(),
                                 nn.Linear(5, 1))



    def forward(self, x):
        x = self.convs[0](x)
        x = self.convs[1](x)
        x = self.convs[2](x)
        x = self.convs[3](x)
        x = x.view(-1)  # flatten the whole batch (2 x 10 x 4 x 4 = 320) into one vector
        x = self.fcs(x)
        return x

model = Net()

loss = nn.L1Loss()
target = Variable(torch.ones(1))

# freeze everything except convs.0.bias and fcs.2.weight
for name, param in model.named_parameters():
    if name == 'convs.0.bias' or name == 'fcs.2.weight':
        param.requires_grad = True
    else:
        param.requires_grad = False

# snapshot the initial state_dict (parameters and buffers)
old_state_dict = {}
for key in model.state_dict():
    old_state_dict[key] = model.state_dict()[key].clone()
print(old_state_dict.keys())

# only pass the trainable parameters to the optimizer
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

for epoch in range(5):
    X = Variable(torch.rand(2, 3, 10, 10))
    out = model(X)
    output = loss(out, target)

    optimizer.zero_grad()  # clear gradients accumulated from the previous iteration
    output.backward()
    optimizer.step()
    
new_state_dict = {}
for key in model.state_dict():
    new_state_dict[key] = model.state_dict()[key].clone()
            
# Compare params
count = 0
for key in old_state_dict:
    if not (old_state_dict[key] == new_state_dict[key]).all():
        print('Diff in {}'.format(key))
        count += 1
print(count)

Output:

dict_keys(['convs.0.weight', 'convs.0.bias', 'convs.1.weight', 'convs.1.bias', 'convs.1.running_mean', 'convs.1.running_var', 'convs.2.weight', 'convs.2.bias', 'convs.3.weight', 'convs.3.bias', 'fcs.0.weight', 'fcs.0.bias', 'fcs.2.weight', 'fcs.2.bias', 'fcs.4.weight', 'fcs.4.bias'])

Diff in convs.1.running_mean
Diff in convs.1.running_var
Diff in fcs.2.weight
3

I was dealing with that ***** the whole day; finally I think I got it. Adding this will make BN not trainable:


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm2d') != -1:
        m.eval()

model.apply(set_bn_eval)
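
One caveat worth noting (a minimal standalone sketch with a made-up two-layer model, not the model above): model.train() puts every submodule, including BatchNorm, back into training mode, so set_bn_eval has to be re-applied after each model.train() call.

import torch.nn as nn

def set_bn_eval(m):
    if m.__class__.__name__.find('BatchNorm2d') != -1:
        m.eval()

model = nn.Sequential(nn.Conv2d(3, 6, 3), nn.BatchNorm2d(6))

model.train()             # switches the BatchNorm layer back to training mode
print(model[1].training)  # True
model.apply(set_bn_eval)  # freeze the running stats again
print(model[1].training)  # False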

In the default settings nn.BatchNorm will have affine trainable parameters (gamma and beta in the original paper, weight and bias in PyTorch) as well as running estimates.
If you don't want to use the batch statistics or update the running estimates, but instead apply the stored running stats, you should call m.eval() as shown in your example.
However, this won’t disable the gradients for weight and bias!
If you don’t want to train them at all, you can just specify affine=False. Otherwise you should treat them as trainable parameters.
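
To make the distinction concrete, here is a small standalone sketch (a fresh nn.BatchNorm2d, not a layer from the model in this thread):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(6)

# 1) m.eval() only controls the running estimates
bn.eval()
mean_before = bn.running_mean.clone()
bn(torch.rand(2, 6, 10, 10))
print(torch.equal(mean_before, bn.running_mean))  # True: stats are not updated in eval mode
print(bn.weight.requires_grad)                    # True: gradients would still flow to weight/bias

# 2) to also freeze the affine parameters, disable their gradients explicitly
bn.weight.requires_grad = False
bn.bias.requires_grad = False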


Are you saying that if I'm fine-tuning from a pretrained model and set affine=False, those parameters will be considered not trainable and will keep whatever values they had in the pretrained model, i.e. they won't be updated?

No, sorry for the misunderstanding.
affine is only considered during the instantiation of the model.
If the nn.BatchNorm layers were already created using affine=True, both parameters will be in the model, and you should treat them like any other parameters, i.e. set requires_grad=False if you don't want to train them further.
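
Putting the two pieces together, a common pattern for a model that was already built with affine=True looks like this (a sketch; the function name and the small Sequential model are just placeholders):

import torch.nn as nn

def freeze_bn(model):
    # keep the running stats fixed and disable gradients for the affine parameters
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
            if m.weight is not None:
                m.weight.requires_grad = False
            if m.bias is not None:
                m.bias.requires_grad = False

model = nn.Sequential(nn.Conv2d(3, 6, 3), nn.BatchNorm2d(6), nn.ReLU())
freeze_bn(model)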


Got it! The model I had was trained with affine=True, and that's why it has running_mean and running_var; apparently requires_grad=False won't make them not trainable (or at least it did not work for me), so I had to put them in eval mode.

Hi,

It means that if the nn.BatchNorm layers were created using affine=False, their weight and bias (gamma and beta) will not be updated and are not trainable parameters, right? Does affine=False equal requires_grad=False or torch.no_grad()?

And model.eval() is for making the nn.BatchNorm layers not update the running estimates.
If I am wrong, please correct me.

Thanks in advance.

If affine was set to False, these parameters are set to None, as shown in this line of code.

Yes, model.eval() will not update the running stats and will instead apply them.
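
For completeness, both points can be checked directly with a standalone sketch:

import torch.nn as nn

bn = nn.BatchNorm2d(6, affine=False)
print(bn.weight, bn.bias)        # None None: no affine parameters were created

bn_affine = nn.BatchNorm2d(6)    # affine=True by default
print(bn_affine.weight.shape)    # torch.Size([6])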
