How to freeze BN layers while training the pretrained model

I have a network, a pretrained ResNet50 model, that consists of batch normalization (BN) layers and other layers (convolution, FC, dropout, etc.).

I don't want the model to be trained, so I froze all layers with requires_grad=False,
but I found that the BN layers were still updating and the performance gradually dropped.
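For context on why requires_grad=False alone does not stop this: running_mean and running_var are buffers, not parameters, and a BN layer in training mode updates them during the forward pass regardless of autograd. A minimal sketch, using a single standalone nn.BatchNorm2d as a stand-in for the layers inside the ResNet50 (an assumption for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One BN layer standing in for the ones inside the pretrained ResNet50.
bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad_(False)  # freezes weight (gamma) and bias (beta) only

before = bn.running_mean.clone()

bn.train()  # model.train() puts every BN layer into this mode
with torch.no_grad():  # even without autograd, the forward pass updates the buffers
    bn(torch.randn(8, 3, 16, 16) + 5.0)  # batch mean (about 5) far from the stored 0

after = bn.running_mean.clone()
print(torch.equal(before, after))  # False: running_mean moved to about 0.5
```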

So I used the code below to freeze the batch norm layers.

import torch.nn as nn

for module in model.modules():  # model: the pretrained ResNet50
    # print(module)
    if isinstance(module, nn.BatchNorm2d):
        if hasattr(module, 'weight'):
            module.weight.requires_grad_(False)
        if hasattr(module, 'bias'):
            module.bias.requires_grad_(False)
        module.track_running_stats = False
        # module.eval()
  1. But I am confused about the difference between calling module.eval() and setting module.track_running_stats = False.

  2. If I want the BN layers not to be updated, so that performance at inference does not change, is setting module.weight.requires_grad_(False) enough? Or should I also turn off track_running_stats,
    since the running statistics are what is used in inference mode?

Thank you.

I just found that if I use the code below, which does not put the layers into .eval() mode but only sets track_running_stats=False,
I get an error.

for module in model.modules():

    # print(module)
    if isinstance(module, nn.BatchNorm2d):
        if hasattr(module, 'weight'):
            module.weight.requires_grad_(False)
        if hasattr(module, 'bias'):
            module.bias.requires_grad_(False)

        module.track_running_stats = False
        #module.eval()

I don’t know what kind of error you are getting, but changing the track_running_stats attribute after the layer has been created might be dangerous. It works in my setup, but I don’t think there is a guarantee that changing this attribute will work in previous or future versions.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
print(bn.running_mean)
# tensor([0., 0., 0.])

out = bn(torch.randn(2, 3, 24, 24))
print(bn.running_mean)
# tensor([0.0056, 0.0054, 0.0008])

bn.track_running_stats = False
out = bn(torch.randn(2, 3, 24, 24))
print(bn.running_mean)
# tensor([0.0056, 0.0054, 0.0008])
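To make the difference from question 1 concrete: in recent PyTorch releases, a BN layer in training mode with track_running_stats switched off after construction stops updating its running statistics but normalizes with the current batch's mean and variance, whereas a layer in eval mode normalizes with the stored running_mean and running_var. So the two are not equivalent, and only eval() keeps the pretrained statistics in use. A minimal sketch; the stored statistics (0.5 and 4.0) are made-up values standing in for pretrained ones, and the train-mode behavior relies on implementation details that may differ between versions:

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8) * 2.0 + 1.0

bn = nn.BatchNorm2d(3)
# Made-up values standing in for pretrained running statistics.
bn.running_mean.fill_(0.5)
bn.running_var.fill_(4.0)
saved_mean = bn.running_mean.clone()

# Option A: training mode, tracking switched off after construction.
bn.train()
bn.track_running_stats = False
out_untracked = bn(x)

# Option B: eval mode (tracking restored to its default).
bn.track_running_stats = True
bn.eval()
out_eval = bn(x)

print(torch.equal(bn.running_mean, saved_mean))  # True: neither option updated it

# Option A normalized with the statistics of this batch ...
batch_mean = x.mean(dim=(0, 2, 3), keepdim=True)
batch_var = ((x - batch_mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
expected_a = (x - batch_mean) / torch.sqrt(batch_var + bn.eps)
# ... while option B used the stored running statistics.
expected_b = (x - 0.5) / math.sqrt(4.0 + bn.eps)

print(torch.allclose(out_untracked, expected_a, atol=1e-4))  # True
print(torch.allclose(out_eval, expected_b, atol=1e-4))       # True
```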

The proper approach would be to use module.eval(). Or is this not working for you?
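A minimal sketch of that approach, assuming a hypothetical helper freeze_bn and a small nn.Sequential stand-in for the pretrained ResNet50. One detail worth knowing: model.train() sets every submodule back to training mode, so the BN layers have to be put back into eval mode after each model.train() call.

```python
import torch
import torch.nn as nn


def freeze_bn(model: nn.Module) -> None:
    """Hypothetical helper: freeze every BatchNorm2d layer in `model`."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()  # use and keep the stored running_mean/running_var
            for p in module.parameters():
                p.requires_grad_(False)  # no gradient updates to weight/bias


torch.manual_seed(0)
# Small stand-in for the pretrained ResNet50 (illustrative assumption).
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

freeze_bn(model)  # set requires_grad before building the optimizer
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1)

for epoch in range(2):
    model.train()     # switches the BN layers back to training mode ...
    freeze_bn(model)  # ... so re-freeze them after every model.train()
    for _ in range(3):
        loss = model(torch.randn(4, 3, 16, 16)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

bn = model[1]
print(bn.training)                                   # False
print(torch.equal(bn.running_mean, torch.zeros(8)))  # True
```

The optimizer only receives parameters that still require gradients, so the conv and linear layers train while the BN weights and statistics stay at their loaded values.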

Calling module.eval() works and does not change the performance.
But the BN layers keep changing running_mean and running_var during training even though I froze them with module.eval().

So I thought I should also turn track_running_stats off to stop updates to running_mean and running_var, which are used in inference.

That shouldn’t be the case. Could you post a minimal, executable code snippet showing this behavior, please?

You are right. I thought running_mean and running_var were changing
because the performance dropped.

When the BN layers are set to eval mode with module.eval(), running_mean and running_var do not change.
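To run the same check on your own model, here is a minimal sketch with hypothetical helpers bn_buffer_snapshot and changed_buffers, which copy every running_mean/running_var buffer and report which ones changed after a forward pass. The small two-BN nn.Sequential is a stand-in for the real network:

```python
import torch
import torch.nn as nn


def bn_buffer_snapshot(model: nn.Module) -> dict:
    """Hypothetical helper: copy every running_mean / running_var buffer."""
    return {name: buf.clone() for name, buf in model.named_buffers()
            if name.endswith(("running_mean", "running_var"))}


def changed_buffers(before: dict, model: nn.Module) -> list:
    """Hypothetical helper: names of the buffers that differ from `before`."""
    after = bn_buffer_snapshot(model)
    return [name for name in before if not torch.equal(before[name], after[name])]


torch.manual_seed(0)
# Stand-in network with two BN layers (illustrative assumption).
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU(),
                      nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4))
x = torch.randn(2, 3, 12, 12)

model.train()
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()  # frozen BN layers
before = bn_buffer_snapshot(model)
model(x)
frozen_changes = changed_buffers(before, model)
print(frozen_changes)  # []

model.train()  # BN layers back in training mode
before = bn_buffer_snapshot(model)
model(x)
train_changes = changed_buffers(before, model)
print(train_changes)  # all four running_mean/running_var buffers
```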