How to freeze BN layers while training the rest of the network (mean and var won't freeze)

I have a network that consists of batch normalization (BN) layers and other layers (convolution, FC, dropout, etc.).
I was wondering how we can do the following:

  1. freeze all the layers and train only the BN layers
  2. freeze the BN layers and train every other layer in the network

My main issue is how to handle freezing and training the BN layers.

You can do something like this; hope it works for you.

model = Net()
for name, child in model.named_children():
    # note: this matches on the submodule's attribute name, so the BN layers
    # must have been registered with 'BatchNorm' in their names
    if name.find('BatchNorm') != -1:
        for param in child.parameters():
            param.requires_grad = True   # train the BN layers
    else:
        for param in child.parameters():
            param.requires_grad = False  # freeze everything else

Thank you for your reply, but the bn.running_mean and bn.running_var values of the BN layers still change when I freeze them.

The running stats are not parameters (they do not get gradients and the optimizer won’t update them), but buffers, which will be updated during the forward pass.
You could call .eval() on the batchnorm layers to apply the running stats instead of calculating and using the batch statistics and updating the running stats (which would be the case during .train()).
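
A quick way to see the distinction (a minimal sketch):

import torch.nn as nn

bn = nn.BatchNorm2d(num_features=3)
# the affine weight and bias are parameters, updated by the optimizer
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
# the running stats are buffers, updated during the forward pass in train() mode
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']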

Thanks for your reply. Can you please give me a simple code snippet showing how I can train the network while freezing the BN layers?

@Usama_Hasan’s code snippet could be used to freeze the affine parameters of all batchnorm layers, but I would rather use isinstance(child, nn.BatchNorm2d) instead of the name check.
Anyway, his code should also work. 🙂
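
For example, a version using isinstance might look like this (a minimal sketch; model.modules() also recurses into nested submodules, and Net stands in for your own model):

import torch.nn as nn

model = Net()  # your model
for module in model.modules():  # recurses into nested submodules
    # use case 1.: train only the BN layers, freeze everything else
    is_bn = isinstance(module, nn.BatchNorm2d)
    for param in module.parameters(recurse=False):  # this module's own params only
        param.requires_grad = is_bn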

Yes, @ptrblck, thanks for the advice; it’s a more Pythonic and stable solution.

Usually I use model.train() when training, but now I don’t want to train the BN layers. Does it sound right if during training I call model.eval() and then do the following:

for name, child in model.named_children():
    if name.find('BatchNorm') != -1:
        for param in child.parameters():
            param.requires_grad = False
    else:
        for param in child.parameters():
            param.requires_grad = True

The code looks alright, but again I would recommend using isinstance instead of the name check. 😉
Note that model.eval() will also disable dropout layers (and maybe change the behavior of custom layers), so you might want to call eval() only on the batchnorm layers.
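
A minimal sketch of that pattern (note that any later call to model.train(), e.g. at the start of each epoch, puts the batchnorm layers back into training mode, so this loop has to be re-applied afterwards):

import torch.nn as nn

model.train()  # dropout etc. stay in training mode
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()  # apply the running stats and stop updating them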

Oh I see,
I didn’t know I could call eval() just on the BN layers.
I’ll give it a try with eval() on the BN layers and use isinstance.
Thanks for clarifying.

Here is what you want to do.

You don’t strictly need both the parameter checks and eval() together, but I did it to be safe.

for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        # weight and bias are None if the layer was created with affine=False
        if module.weight is not None:
            module.weight.requires_grad_(False)
        if module.bias is not None:
            module.bias.requires_grad_(False)
        module.eval()  # also freeze running_mean/running_var

@ptrblck I use module.eval() to freeze the stats of an instancenorm layer (with track_running_stats=True) after some epochs. However, the running_mean and running_var are still being updated during training. Any idea why this may happen?

I cannot reproduce this issue using this code snippet and the running stats are equal after calling norm.eval():

import torch
import torch.nn as nn

norm = nn.InstanceNorm2d(num_features=3, track_running_stats=True)
print(norm.running_mean, norm.running_var)
> tensor([0., 0., 0.]) tensor([1., 1., 1.])

x = torch.randn(2, 3, 24, 24)

out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0029,  0.0005,  0.0003]) tensor([0.9988, 1.0021, 0.9980])

out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0056,  0.0010,  0.0006]) tensor([0.9978, 1.0040, 0.9962])

norm.eval()
out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0056,  0.0010,  0.0006]) tensor([0.9978, 1.0040, 0.9962])

Are you using an older PyTorch version, where this could have been a known issue?

I use version 1.6 and the issue arises when I switch to eval() during training. I have a situation like the following:

from torch.nn import InstanceNorm1d

def _freeze_norm_stats(net):
    try:
        for m in net.modules():
            if isinstance(m, InstanceNorm1d):
                m.track_running_stats = False
                m.eval()
    except ValueError:
        print("error while freezing the instancenorm stats")
        return

model = Model().cuda()

for epoch in range(start_epoch, end_epoch):
    # for epochs < 10, the model updates the running_mean and running_var
    # of its instancenorm layers
    if epoch == 10:
        print("freeze stats:")
        model.apply(_freeze_norm_stats)
    # for epochs >= 10, the model should continue with fixed mean and var
    # for its instancenorm layers

Unfortunately, printing the running_mean and running_var after calling _freeze_norm_stats shows that they’re still being updated over training steps. Maybe I have something wrong in the function implementation, but I am still not able to figure this out.

Could you update to the latest stable release (1.7.1) and check if you are still facing this issue?

Removing the m.track_running_stats = False line and keeping only the .eval() call solved the problem.
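
For completeness, the helper that worked would then reduce to something like this (a sketch):

from torch.nn import InstanceNorm1d

def _freeze_norm_stats(net):
    for m in net.modules():
        if isinstance(m, InstanceNorm1d):
            m.eval()  # keep track_running_stats=True; eval() freezes the stats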

What would be the easiest way to freeze the batchnorm layers in, say, layer4 of ResNet34? I am finetuning only layer4, so I plan to check both with and without freezing the BN layers.

I checked resnet34.layer4.named_children() and can write loops to fetch the BN layers inside layer4, but I want to check if there is a more elegant way.
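
One option, since model.modules() already recurses through the blocks inside layer4, is a sketch like this (assuming torchvision’s resnet34):

import torch.nn as nn
from torchvision import models

model = models.resnet34(pretrained=True)
for module in model.layer4.modules():  # recurses through the blocks in layer4
    if isinstance(module, nn.BatchNorm2d):
        module.eval()  # freeze the running stats
        for param in module.parameters():
            param.requires_grad = False  # freeze the affine weight and bias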

I encountered the same issue: the running_mean/running_var of a batchnorm layer were still being updated even after calling bn.eval(). It turned out that the only way to freeze the running_mean/running_var was bn.track_running_stats = False. I tried 3 settings:

  • bn.param.requires_grad = False & bn.eval()
  • bn.param.requires_grad = False & bn.track_running_stats = False & bn.eval()
  • bn.param.requires_grad = False & bn.track_running_stats = False

The first one did not work; both the 2nd and 3rd worked. So my conclusion is: use bn.track_running_stats = False if you want to freeze running_mean/running_var. (This contradicts How to freeze BN layers while training the rest of network (mean and var wont freeze) - #17 by Ahmed_m.)

Additional info: torch.__version__ is 1.7.0. There seems to be a behavior difference across versions.

I tried your way and it absolutely works with setting track_running_stats=False. However, just out of curiosity: the official PyTorch documentation says that if track_running_stats is set to False, then running_mean and running_var will be None and batch norm will always use the batch statistics.
My guess is that setting track_running_stats = False after construction does not reset running_mean and running_var, and hence all it does is stop updating them. Please correct me if my assumption is wrong.

Yes, using track_running_stats=False will always normalize the input activations with the current batch stats, so you would have to make sure that enough samples are passed to the model (even during testing).
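
To summarize the two behaviors discussed here (a minimal sketch):

import torch.nn as nn

# created with track_running_stats=False: no running stats exist at all,
# so the batch stats are used in both train() and eval() mode
bn = nn.BatchNorm2d(3, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # None None

# default track_running_stats=True: eval() applies the running stats
# without updating them
bn = nn.BatchNorm2d(3)
bn.eval()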