Proper way of freezing BatchNorm running statistics

Hi everybody,

What I want to do is use a pretrained network that contains batch normalization layers and fine-tune it. So I want to freeze the weights of the network. Apart from freezing the weight and bias of batch norm, I would also like to freeze the running_mean and running_var buffers and keep the values from the pretrained network. I've seen many posts that address this issue, but there is no clear answer. The general advice is to put the batchnorm layers in eval mode. But people report that if you first put the whole model in train mode and then put only the batchnorm layers in eval mode, training does not converge. Another post suggests overriding the train() function so that it puts the batchnorm layers in eval mode. Is this the way to go? Does anyone have a better suggestion? Thank you.
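For reference, the train() override that those posts describe looks roughly like the sketch below (MyModel and its layers are just placeholders, not an actual architecture):

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.backbone(x)

    def train(self, mode=True):
        # Switch the whole model as usual...
        super().train(mode)
        # ...then force every BatchNorm layer back into eval mode so that
        # running_mean / running_var are never updated during training.
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
        return self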


Hey, I'm running into exactly what you described almost 3 years ago: my network won't converge after I call module.eval() on all my BatchNorm2d modules. But I still need to find a way to freeze the running stats of those layers. I wonder if you have found a way in the meantime?

Hi,

You can freeze BN trainable parameters with something like:

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        # Stop gradients from flowing into the affine parameters.
        m.weight.requires_grad_(False)
        m.bias.requires_grad_(False)

This should be done once, before training starts. Also, don't forget to pass only the trainable parameters to the optimiser, e.g. as in the sketch below.
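Something like this, for example (just a sketch; the SGD settings are arbitrary):

import torch.optim as optim

# Only parameters that still require gradients are given to the optimiser.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=1e-3, momentum=0.9)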

And then you can freeze BN statistics at the beginning of each epoch (after you call model.train()) with:

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        # eval() stops updates to running_mean / running_var;
        # weight and bias are still trained if they require grad.
        m.eval()
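Put together, the start of each epoch would look something like this (a rough sketch; num_epochs, train_loader and criterion are placeholders for your own training setup):

for epoch in range(num_epochs):
    model.train()                      # whole model in train mode first
    for m in model.modules():          # then re-freeze the BN running stats
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()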

Hope this helps!

First of all, thanks for your immediate reply.
Regarding putting the batchnorm layers in eval mode, I am doing exactly what you suggest here.
Regarding setting the trainable parameters so that they don't require gradients, I found that this was actually the exact cause of the gradient explosion in my case (I'm training a ResNet-based architecture for 3D human pose estimation with half precision). After I switched the code back so that bn.weight and bn.bias are updated as usual (while the bn layers still run in eval mode at training time), the gradient explosion stopped. For my own experiments this is the ideal behaviour: at training time the trainable params of the bn layers are updated as usual, but the running stats stay fixed.