How to freeze BN layers while training the rest of the network (mean and var won't freeze)

The running stats are not parameters (they do not get gradients and the optimizer won’t update them). They are buffers, which are updated during the forward pass in training mode.
You could call .eval() on the batchnorm layers to apply the running stats instead of calculating and using the batch statistics and updating the running stats (which would be the case during .train()).
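A minimal sketch of this, assuming a toy model with a single nn.BatchNorm2d (the layer sizes are arbitrary): the running stats show up in named_buffers() rather than named_parameters(), and calling .eval() on just the batchnorm layer stops them from changing on the next forward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model, only for illustration
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
bn = model[1]

# Affine weight/bias are parameters; the running stats are buffers
param_names = [name for name, _ in bn.named_parameters()]
buffer_names = [name for name, _ in bn.named_buffers()]
print(param_names)   # ['weight', 'bias']
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']

x = torch.randn(4, 3, 16, 16)

# In training mode the forward pass updates the running stats
model.train()
before = bn.running_mean.clone()
model(x)
updated_in_train = not torch.equal(before, bn.running_mean)

# After bn.eval(), the forward pass uses (and does not update) the running stats
bn.eval()
before = bn.running_mean.clone()
model(x)
frozen_in_eval = torch.equal(before, bn.running_mean)

print(updated_in_train, frozen_in_eval)  # True True
```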


Thanks for your reply. Can you please give me a simple code snippet showing how I can train the network while freezing the BN layers?

@Usama_Hasan’s code snippet could be used to freeze the affine parameters of all batchnorm layers, but I would rather use isinstance(child, nn.BatchNorm2d) instead of the name check.
Anyway, his code should also work. :slight_smile:


Yes, @ptrblck thanks for the advice, it’s a more pythonic and stable solution.


Usually, I use model.train() when training, but now I don’t want to train the BN layers during training. Does it sound right if I call model.eval() during training and then do the following:


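# Note: this matches the attribute name (e.g. 'bn1'), not the layer type,
# and named_children() only visits the direct children of the model
for name, child in model.named_children():
    if name.find('BatchNorm') != -1:
        for param in child.parameters():
            param.requires_grad = False
    else:
        for param in child.parameters():
            param.requires_grad = True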

The code looks alright, but again I would recommend using isinstance instead of the name check. :wink:
Note that model.eval() will also disable dropout layers (and may change the behavior of custom layers), so you might want to call eval() only on the batchnorm layers.
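A minimal sketch of calling eval() only on the batchnorm layers, assuming a toy model with one BatchNorm1d and one Dropout layer. The helper name train_with_frozen_bn is hypothetical, and only the standard BatchNorm1d/2d/3d classes are matched (SyncBatchNorm or custom norm layers would need to be added to the tuple).

```python
import torch.nn as nn

# Batchnorm classes to freeze (extend if the model uses other norm layers)
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def train_with_frozen_bn(model):
    # Put everything (dropout included) into training mode first
    model.train()
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            # Normalize with the running stats and stop updating them
            module.eval()
    return model


model = nn.Sequential(
    nn.Linear(10, 10),
    nn.BatchNorm1d(10),
    nn.Dropout(p=0.5),
    nn.Linear(10, 2),
)
train_with_frozen_bn(model)

print(model[1].training)  # False: batchnorm uses running stats
print(model[2].training)  # True: dropout is still active
```

Since model.train() switches the batchnorm layers back to training mode, a helper like this would have to be called at the start of every training epoch, e.g. in place of model.train().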


Oh, I see.
I didn’t know I could call eval() on just the BN layers.
I will give it a try with eval() on the BN layers and will use isinstance.
Thanks for clarifying!


Here is what you want to do.

You don’t need to use both hasattr and eval together, but I did it just to be safe.

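import torch.nn as nn

for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        # weight/bias exist but are None when the layer was created with affine=False
        if hasattr(module, 'weight') and module.weight is not None:
            module.weight.requires_grad_(False)
        if hasattr(module, 'bias') and module.bias is not None:
            module.bias.requires_grad_(False)
        module.eval()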

@ptrblck I use module.eval() to freeze the stats of an instancenorm layer (with track_running_stats = True) after some epochs. However, the running_mean and running_var are still being updated during training. Any idea why this may happen?

I cannot reproduce this issue using this code snippet and the running stats are equal after calling norm.eval():

import torch
import torch.nn as nn

norm = nn.InstanceNorm2d(num_features=3, track_running_stats=True)
print(norm.running_mean, norm.running_var)
> tensor([0., 0., 0.]) tensor([1., 1., 1.])

x = torch.randn(2, 3, 24, 24)

out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0029,  0.0005,  0.0003]) tensor([0.9988, 1.0021, 0.9980])

out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0056,  0.0010,  0.0006]) tensor([0.9978, 1.0040, 0.9962])

norm.eval()
out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0056,  0.0010,  0.0006]) tensor([0.9978, 1.0040, 0.9962])

Are you using an older PyTorch version, where this could have been a known issue?

I use version 1.6, and the issue arises when I switch to eval() during training. My setup looks like the following:

def _freeze_norm_stats(net):
    try:
        for m in net.modules():
            if isinstance(m, InstanceNorm1d):
                m.track_running_stats = False
                m.eval()

    except ValueError:
        print("errrrrrrrrrrrrrroooooooorrrrrrrrrrrr with instancenorm")
        return

model = Model().cuda()

for epoch in range(start_epoch, end_epoch):
    # for epochs < 10, the model updates the running_mean and running_var of its instancenorm layer
    if epoch == 10:
        print("freeze stats:")
        model.apply(_freeze_norm_stats)
    # for epochs >= 10, the model should continue with fixed mean and var for its instancenorm layer

Unfortunately, printing the running_mean and running_var after calling _freeze_norm_stats shows that they are still being updated over the training steps. Maybe something is wrong in my function implementation, but I have not been able to figure it out.

Could you update to the latest stable release (1.7.1) and check, if you are still facing this issue?

Removing the m.track_running_stats = False line and keeping only the .eval() call solved the problem.
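A small sketch of why that line mattered. It assumes the InstanceNorm forward in current PyTorch releases, which (as far as I can tell from the module source) passes `self.training or not self.track_running_stats` as the `use_input_stats` argument of F.instance_norm. Flipping the flag after construction leaves the buffers in place, so in eval mode the layer goes back to using the input statistics and keeps updating the buffers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 24)

# eval() alone: the running stats are used and stay fixed
norm = nn.InstanceNorm1d(3, track_running_stats=True)
norm(x)  # one training-mode step so the stats move away from their defaults
norm.eval()
before = norm.running_mean.clone()
norm(x)
eval_only_frozen = torch.allclose(before, norm.running_mean)

# eval() plus track_running_stats=False (as in _freeze_norm_stats):
# the buffers still exist, input stats are used again and the buffers are updated
norm.track_running_stats = False
before = norm.running_mean.clone()
norm(x)
flag_and_eval_frozen = torch.allclose(before, norm.running_mean)

print(eval_only_frozen, flag_and_eval_frozen)  # True False
```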

What would be the easiest way to freeze the batchnorm layers in, say, layer4 of ResNet34? I am fine-tuning only layer4, so I plan to compare training with and without frozen BN layers.

I checked resnet34.layer4.named_children() and could write loops to fetch the BN layers inside layer4, but I wanted to check whether there is a more elegant way.
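Since Module.modules() already recurses into nested blocks, one option is a small helper that takes any submodule (for torchvision's resnet34 that would be model.layer4). This is a sketch under assumptions: the helper name freeze_bn is hypothetical, and TinyNet is a toy stand-in with nested blocks so the example runs without torchvision; it is not the real ResNet34.

```python
import torch.nn as nn


def freeze_bn(module: nn.Module) -> None:
    # Hypothetical helper: freeze every BatchNorm2d inside `module`, recursively
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()  # use and keep the running stats
            for p in m.parameters():
                p.requires_grad_(False)  # freeze the affine weight/bias


# Toy stand-in for a ResNet with a nested "layer4" block (not the real resnet34)
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.layer4 = nn.Sequential(
            nn.Sequential(nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8), nn.ReLU()),
            nn.Sequential(nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8), nn.ReLU()),
        )


model = TinyNet()
model.train()
freeze_bn(model.layer4)

bn_layers = [m for m in model.layer4.modules() if isinstance(m, nn.BatchNorm2d)]
print(len(bn_layers), all(not m.training for m in bn_layers))  # 2 True
print(model.layer3[1].training)  # True: only layer4 is affected
```

A later model.train() call switches these layers back to training mode, so freeze_bn(model.layer4) would need to be re-applied after each model.train().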

I encountered the same issue: the running_mean/running_var of a batchnorm layer are still being updated even after “bn.eval()”. It turns out that the only way to freeze the running_mean/running_var is “bn.track_running_stats = False”. I tried 3 settings:

  • bn.param.requires_grad = False & bn.eval()
  • bn.param.requires_grad = False & bn.track_running_stats = False & bn.eval()
  • bn.param.requires_grad = False & bn.track_running_stats = False

The first one did not work; both the 2nd and 3rd worked. So my conclusion is: use “bn.track_running_stats = False” if you want to freeze running_mean/running_var. (This contradicts How to freeze BN layers while training the rest of network (mean and var wont freeze) - #17 by Ahmed_m.)

Additional info: torch.__version__: 1.7.0. It seems the behavior differs across versions.
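One possible explanation for the first setting not working (an assumption, since the full training loop isn't shown) is that model.train() is called after bn.eval(): it recursively sets every submodule, including the batchnorm layers, back to training mode. A minimal sketch of that effect on a toy model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
bn = model[1]
x = torch.randn(8, 4)

bn.eval()
before = bn.running_mean.clone()
model(x)
frozen_after_eval = torch.equal(before, bn.running_mean)

# A typical training loop calls model.train() at the start of each epoch...
model.train()
# ...which silently undoes bn.eval()
bn_training_again = bn.training
before = bn.running_mean.clone()
model(x)
frozen_after_train = torch.equal(before, bn.running_mean)

print(frozen_after_eval, bn_training_again, frozen_after_train)  # True True False
```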


I tried your approach, and it works when setting track_running_stats=False. However, out of curiosity: the official PyTorch documentation says that if track_running_stats is set to False, running_mean and running_var will be None and batch norm will always use batch statistics?
My guess is that setting track_running_stats=False on an existing layer does not reset running_mean and running_var; it only stops updating them. Please correct me if my assumption is wrong.

Yes, using track_running_stats=False will always normalize the input activation with the current batch stats so you would have to make sure that enough samples are passed to the model (even during testing).

The documentation actually says:

track_running_stats: […] when set to False, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes.
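A sketch contrasting the two cases, based on my reading of the BatchNorm forward in recent PyTorch releases (older versions such as 1.7 may behave differently). A layer constructed with track_running_stats=False has None buffers and normalizes with batch statistics even in eval mode. Flipping the flag on an existing layer keeps the buffers, which then stop being updated in training mode, matching the guess above for current versions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)

# Case 1: flag set at construction -> no buffers, batch stats even in eval mode
bn_none = nn.BatchNorm2d(3, affine=False, track_running_stats=False)
bn_none.eval()
out = bn_none(x)
# Manual normalization with the biased batch variance over (N, H, W)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + bn_none.eps)
uses_batch_stats = torch.allclose(out, manual, atol=1e-5)
print(bn_none.running_mean is None, uses_batch_stats)  # True True

# Case 2: flag flipped on an existing layer -> buffers kept, no longer updated
bn_flip = nn.BatchNorm2d(3)
bn_flip(x)  # one regular training-mode step
bn_flip.track_running_stats = False
before = bn_flip.running_mean.clone()
bn_flip(x)  # still in training mode
buffers_frozen = torch.equal(before, bn_flip.running_mean)
print(bn_flip.running_mean is not None, buffers_frozen)  # True True
```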

Doesn’t that mean that when a state dict is loaded which contains running_mean and running_var, these will be used instead of the current batch statistics?

Also: how are users supposed to switch between train and eval mode for training and validation in each epoch, when model.train() sets all layers to training mode, even those that we initially froze by calling layer.eval()?

No, since it will directly fail:

import torch.nn as nn

bn1 = nn.BatchNorm2d(3, track_running_stats=False)
bn2 = nn.BatchNorm2d(3)

print(bn1.running_mean)
# None
print(bn1.running_var)
# None
bn1.load_state_dict(bn2.state_dict())
# RuntimeError: Error(s) in loading state_dict for BatchNorm2d:
# 	Unexpected key(s) in state_dict: "running_mean", "running_var", "num_batches_tracked". 

I see! The load_state_dict documentation actually talks about this:

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

So, don’t mess with track_running_stats to freeze a BatchNorm layer, I guess.

But my last question remains: how do I reliably freeze a BatchNorm layer so that it does not get accidentally unfrozen when the training loop switches between train and eval mode for training and validation?

My current idea would be to monkeypatch the train method:

from copy import deepcopy

import torch
import torch.nn as nn


encoder = nn.BatchNorm1d(4)

assert (encoder.running_mean == 0).all()
assert (encoder.running_var == 1).all()


# Train encoder
x = torch.randn((8, 4))
encoder(x)

# Training changes the running stats
assert torch.norm(encoder.running_mean) != 0
assert (encoder.running_var != 1).any()

encoder_state_dict_orig = deepcopy(encoder.state_dict())

# Freeze encoder
encoder.eval()
# Prevent encoder from ever being switched back to training mode
encoder.train = lambda mode=True: encoder

classifier = nn.Linear(4, 1)

model = nn.Sequential(encoder, classifier)

# Training loop:
model.train()
assert not encoder.training

# train...
x = torch.randn((8, 4))
output = model(x)
# (+ the usual stuff...)

model.eval()
# validate...
assert not encoder.training

x = torch.randn((8, 4))
output = model(x)

# After freezing, the running stats did not change anymore, even if the model was trained
encoder_state_dict = encoder.state_dict()
assert all(
    (encoder_state_dict[k] == v).all() for k, v in encoder_state_dict_orig.items()
)

The original train method could be saved so that it can be restored later, if needed.
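A sketch of how that could look, assuming the monkeypatch from above (an instance attribute shadowing nn.Module.train): keep a reference to the bound method before patching, then assign it back to unfreeze.

```python
import torch.nn as nn

encoder = nn.BatchNorm1d(4)
classifier = nn.Linear(4, 1)
model = nn.Sequential(encoder, classifier)

# Freeze: switch to eval and shadow train() with a no-op on this instance only
encoder.eval()
original_train = encoder.train  # bound nn.Module.train, saved for later
encoder.train = lambda mode=True: encoder

model.train()
frozen = not encoder.training  # still in eval mode

# Unfreeze: restore the saved method
encoder.train = original_train
model.train()
unfrozen = encoder.training

print(frozen, unfrozen)  # True True
```

Alternatively, `del encoder.train` removes the instance attribute, so the class method nn.Module.train is found again without having saved anything.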