How to freeze BN layers while training the rest of the network (mean and var won't freeze)

@Usama_Hasan’s code snippet could be used to freeze the affine parameters of all batchnorm layers, but I would rather use isinstance(child, nn.BatchNorm2d) instead of the name check.
Anyway, his code should also work. :slight_smile:


Yes, @ptrblck thanks for the advice, it’s a more pythonic and stable solution.


Usually I use model.train() when training, but now I don't want to train the BN layers during training. Does it sound right if, during training, I call model.eval() and then do the following?


for name, child in model.named_children():
    if name.find('BatchNorm') != -1:
        for param in child.parameters():
            param.requires_grad = False
    else:
        for param in child.parameters():
            param.requires_grad = True

The code looks alright, but again I would recommend using isinstance instead of the name check. :wink:
Note that model.eval() will also disable dropout layers (and maybe change the behavior of custom layers), so you might want to call eval() only on the batchnorm layers.
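For example, a minimal sketch (assuming the model only uses nn.BatchNorm2d for normalization) could look like this:

import torch.nn as nn

# Keep the whole model trainable, but put only the batchnorm layers into eval
# mode and freeze their affine parameters.
model.train()
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()  # running_mean/running_var will no longer be updated
        for param in module.parameters():
            param.requires_grad = False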


Oh I see, I didn't know I could call eval() just on the BN layers.
I will give it a try with eval() on the BN layers and will use isinstance.
Thanks for clarifying!


Here is what you want to do.

You don't need to use both hasattr and eval() together, but I just did it to be safe.

for module in model.modules():
    # print(module)
    if isinstance(module, nn.BatchNorm2d):
        if hasattr(module, 'weight'):
            module.weight.requires_grad_(False)
        if hasattr(module, 'bias'):
            module.bias.requires_grad_(False)
        module.eval()

@ptrblck I use module.eval() to freeze the stats of an instancenorm layer (with track_running_stats=True) after some epochs. However, the running_mean and running_var are still updating over training. Any idea why this may happen?

I cannot reproduce this issue using this code snippet and the running stats are equal after calling norm.eval():

norm = nn.InstanceNorm2d(num_features=3, track_running_stats=True)
print(norm.running_mean, norm.running_var)
> tensor([0., 0., 0.]) tensor([1., 1., 1.])

x = torch.randn(2, 3, 24, 24)

out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0029,  0.0005,  0.0003]) tensor([0.9988, 1.0021, 0.9980])

out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0056,  0.0010,  0.0006]) tensor([0.9978, 1.0040, 0.9962])

norm.eval()
out = norm(x)
print(norm.running_mean, norm.running_var)
> tensor([-0.0056,  0.0010,  0.0006]) tensor([0.9978, 1.0040, 0.9962])

Are you using an older PyTorch version, where this could have been a known issue?

I use version 1.6 and the issue arises when I switch to eval() during training. I have a situation like the following:

def _freeze_norm_stats(net):
    try:
        for m in net.modules():
            if isinstance(m, InstanceNorm1d):
                m.track_running_stats = False
                m.eval()

    except ValueError:
        print("error with instancenorm")
        return

model = Model().cuda()

for epoch in range(start_epoch, end_epoch):
    # for epochs < 10, the model updates the running_mean and running_var of its instancenorm layer
    if epoch == 10:
        print("freeze stats:")
        model.apply(_freeze_norm_stats)
    # for epochs >= 10, the model should continue with fixed mean and var for its instancenorm layer

Unfortunately, printing the running_mean and running_var after calling _freeze_norm_stats shows that they're still being updated over training steps. Maybe I have something wrong in the function implementation, but I am still not able to figure this out.

Could you update to the latest stable release (1.7.1) and check if you are still facing this issue?

Removing the m.track_running_stats = False line and only keeping the .eval() call solved the problem.
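For reference, the working version of the function is then just (a sketch based on the snippet above):

import torch.nn as nn

def _freeze_norm_stats(net):
    # Only switch the norm layers to eval mode; leave track_running_stats untouched.
    for m in net.modules():
        if isinstance(m, nn.InstanceNorm1d):
            m.eval()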

What could be the easiest way to freeze the batchnorm layers in, say, layer4 of ResNet34? I am finetuning only layer4, so I plan to check both with and without freezing the BN layers.

I checked resnet34.layer4.named_children() and can write a loop to fetch the BN layers inside layer4 (see the sketch below), but wanted to check whether there is a more elegant way.
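For reference, the kind of loop I have in mind (a sketch assuming torchvision's resnet34; note that the eval() calls would be undone by a later model.train() call):

import torch.nn as nn
from torchvision import models

resnet34 = models.resnet34(pretrained=True)

# Freeze only the batchnorm layers inside layer4.
for module in resnet34.layer4.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()
        for param in module.parameters():
            param.requires_grad = False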

I encountered the same issue: the running_mean/running_var of a batchnorm layer were still being updated even after calling bn.eval(). It turns out that the only way to freeze the running_mean/running_var is bn.track_running_stats = False. I tried three settings:

  • bn.param.requires_grad = False & bn.eval()
  • bn.param.requires_grad = False & bn.track_running_stats = False & bn.eval()
  • bn.param.requires_grad = False & bn.track_running_stats = False

The first one did not work; both the 2nd and 3rd worked. So my conclusion is to use bn.track_running_stats = False if you want to freeze running_mean/running_var (see the quick check below). (This contradicts How to freeze BN layers while training the rest of network (mean and var wont freeze) - #17 by Ahmed_m.)

Additional info: torch.__version__ is 1.7.0. There seems to be a misalignment across versions.
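A quick check of the third setting (a sketch of the behavior I observed on 1.7; the layer stays in train() mode the whole time):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 8, 8)

bn(x)  # a first forward pass updates the running stats as usual
mean_before = bn.running_mean.clone()

# Freeze the affine parameters and stop tracking the running stats.
for p in bn.parameters():
    p.requires_grad = False
bn.track_running_stats = False

bn(x)  # still in train() mode
print(torch.equal(mean_before, bn.running_mean))
# > True (the running stats are no longer updated)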


I tried your way and it absolutely works with setting track_running_stats=False. However, just out of curiosity: the official PyTorch documentation says that if track_running_stats is set to False, then running_mean and running_var will be None and batch norm will always use the batch statistics.
My guess is that setting track_running_stats=False after construction will not reset running_mean and running_var, and hence all it does is stop updating them. Please correct me if my assumption is wrong.

Yes, using track_running_stats=False will always normalize the input activation with the current batch stats, so you would have to make sure that enough samples are passed to the model (even during testing).

The documentation actually says:

track_running_stats: […] when set to False, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes.
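A quick illustration of that behavior (sketch):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, track_running_stats=False)
bn.eval()

x = torch.randn(2, 3, 24, 24) * 5 + 10  # deliberately shifted and scaled input
out = bn(x)

# Even in eval mode the current batch statistics are used, so each output
# channel is roughly zero-mean with unit variance.
print(out.mean(dim=(0, 2, 3)))
print(out.var(dim=(0, 2, 3)))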

Doesn’t that mean that when a state dict is loaded which contains running_mean and running_var, these will be used instead of the current batch statistics?

Also: how should users go about switching between train and eval mode for training and validation in each epoch, when model.train() sets all layers to training mode, even those that we initially froze by calling layer.eval()?

No, since it will directly fail:

bn1 = nn.BatchNorm2d(3, track_running_stats=False)
bn2 = nn.BatchNorm2d(3)

print(bn1.running_mean)
# None
print(bn1.running_var)
# None
bn1.load_state_dict(bn2.state_dict())
# RuntimeError: Error(s) in loading state_dict for BatchNorm2d:
# 	Unexpected key(s) in state_dict: "running_mean", "running_var", "num_batches_tracked". 

I see! The load_state_dict documentation actually talks about this:

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

So, don’t mess with track_running_stats to freeze a BatchNorm layer, I guess.

But my last question persists: How do I sustainably freeze a BatchNorm layer so that it does not get accidentally unfrozen when the training loop switches between train and eval mode for training and validation?

My current idea would be to monkeypatch the train method:

import torch
import torch.nn as nn
from copy import deepcopy


encoder = nn.BatchNorm1d(4)

assert (encoder.running_mean == 0).all()
assert (encoder.running_var == 1).all()


# Train encoder
x = torch.randn((8, 4))
encoder(x)

# Training changes the running stats
assert (encoder.running_mean != 0).any()
assert (encoder.running_var != 1).any()

encoder_state_dict_orig = deepcopy(encoder.state_dict())

# Freeze encoder
encoder.eval()
# Prevent encoder from ever being switched back to training mode
encoder.train = lambda mode=True: encoder

classifier = nn.Linear(4, 1)

model = nn.Sequential(encoder, classifier)

# Training loop:
model.train()
assert not encoder.training

# train...
x = torch.randn((8, 4))
output = model(x)
# (+ the usual stuff...)

model.eval()
# validate...
assert not encoder.training

x = torch.randn((8, 4))
output = model(x)

# After freezing, the running stats did not change anymore, even though the model ran in train mode
encoder_state_dict = encoder.state_dict()
assert all(
    (encoder_state_dict[k] == v).all() for k, v in encoder_state_dict_orig.items()
)

The original train method could be saved so that it can be restored later, if needed.

Monkey-patching could work, but as you know it could also easily break.
The recommended way might be to make sure to call model.batchnormlayer.eval() separately after each model.train() call, but that of course doesn't meet your wish of not having to care about it.
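For example, a small hypothetical helper wrapping that pattern:

def train_with_frozen_layers(model, frozen_modules):
    # Hypothetical helper: switch the whole model to train mode, then push the
    # frozen layers back into eval mode so their running stats stay fixed.
    model.train()
    for m in frozen_modules:
        m.eval()
    return model

You would then call it instead of model.train() at the start of each training epoch.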


Thanks!

While this of course works, I think it violates the “separation of concerns” principle. (Why should the training loop know about some special layers?)

Do you think a pull request that introduces Model.freeze() in addition to Model.train()/eval(), and ensures that frozen layers remain frozen even when calling Model.train() on a parent, would find supporters? I think freezing parts of a model is a very common practice, and this forum is full of questions on how to do it properly. (And I believe many people will miss some delicate subtleties, such as remembering to eval() these parts explicitly in the training loop.)