How to freeze BN layers while training the rest of the network (mean and var won’t freeze)

The documentation actually says:

track_running_stats: […] when set to False, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics in both training and eval modes.
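For context, this is easy to check (a minimal sketch, showing nothing beyond the quoted behavior):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # None None

# Even in eval mode the layer normalizes with the current batch statistics,
# so the output (with identity-initialized affine parameters) is roughly
# zero-mean, unit-variance per channel.
bn.eval()
x = torch.randn(16, 3, 8, 8) * 5 + 10
out = bn(x)
print(out.mean().item(), out.var().item())  # ~0.0 and ~1.0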

Doesn’t that mean that when a state dict containing running_mean and running_var is loaded, these will be used instead of the current batch statistics?

Also: How are users supposed to switch between train and eval mode for training and validation in each epoch, when model.train() sets all layers back to training mode, even those that we initially froze by calling layer.eval()?

No, since loading the state_dict will fail directly:

import torch.nn as nn

bn1 = nn.BatchNorm2d(3, track_running_stats=False)
bn2 = nn.BatchNorm2d(3)

print(bn1.running_mean)
# None
print(bn1.running_var)
# None
bn1.load_state_dict(bn2.state_dict())
# RuntimeError: Error(s) in loading state_dict for BatchNorm2d:
# 	Unexpected key(s) in state_dict: "running_mean", "running_var", "num_batches_tracked". 

I see! The load_state_dict documentation actually talks about this:

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

So, don’t mess with track_running_stats to freeze a BatchNorm layer, I guess.
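(For reference, the usual alternative seems to be to keep the default track_running_stats=True, switch the layer to eval mode, and optionally freeze its affine parameters. A sketch:)

import torch.nn as nn

bn = nn.BatchNorm2d(3)

# Stop updating running_mean / running_var
bn.eval()

# Optionally also freeze the affine parameters (weight and bias)
for p in bn.parameters():
    p.requires_grad_(False)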

But my last question persists: How do I sustainably freeze a BatchNorm layer so that it does not get accidentally unfrozen when the training loop switches between train and eval mode for training and validation?
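To make the problem concrete (a minimal sketch):

import torch.nn as nn

encoder = nn.BatchNorm1d(4)
classifier = nn.Linear(4, 1)
model = nn.Sequential(encoder, classifier)

# "Freeze" the encoder
encoder.eval()
assert not encoder.training

# The next epoch's model.train() silently unfreezes it again
model.train()
assert encoder.training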

My current idea would be to monkeypatch the train method:

from copy import deepcopy

import torch
import torch.nn as nn

encoder = nn.BatchNorm1d(4)

assert (encoder.running_mean == 0).all()
assert (encoder.running_var == 1).all()


# Train encoder
x = torch.randn((8, 4))
encoder(x)

# Training changed the running stats (they started at 0 and 1)
assert (encoder.running_mean != 0).any()
assert (encoder.running_var != 1).any()

encoder_state_dict_orig = deepcopy(encoder.state_dict())

# Freeze encoder
encoder.eval()
# Prevent encoder from ever being set to training mode again
encoder.train = lambda mode=True: encoder

classifier = nn.Linear(4, 1)

model = nn.Sequential(encoder, classifier)

# Training loop:
model.train()
assert not encoder.training

# train...
x = torch.randn((8, 4))
output = model(x)
# (+ the usual stuff...)

model.eval()
# validate...
assert not encoder.training

x = torch.randn((8, 4))
output = model(x)

# After freezing, the running stats did not change anymore, even though the model was trained
encoder_state_dict = encoder.state_dict()
assert all(
    (encoder_state_dict[k] == v).all() for k, v in encoder_state_dict_orig.items()
)

The original train method could be saved so that it can be restored later, if needed.
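For example (a sketch based on the snippet above):

# Save the original (bound) train method before patching it
original_train = encoder.train
encoder.train = lambda mode=True: encoder

# ... later, restore it to make the encoder trainable again
encoder.train = original_train  # or: del encoder.train
encoder.train()
assert encoder.training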

Monkey-patching could work, but as you know it could also break easily.
The recommended way might be to call model.batchnormlayer.eval() explicitly after each model.train() call, but that of course doesn’t meet your wish of not having to care about it.
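Something along these lines, for example (a sketch; set_train_mode and frozen_modules are just made-up names):

def set_train_mode(model, frozen_modules=()):
    """Switch the model to train mode, then re-freeze the given submodules."""
    model.train()
    for module in frozen_modules:
        module.eval()

# In the training loop, instead of model.train():
set_train_mode(model, frozen_modules=[model[0]])  # e.g. the frozen BatchNorm encoder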


Thanks!

While this works, of course, I think it violates the “separation of concerns” principle. (Why should the training loop have to know about some special layers?)

Do you think a pull request that introduces Model.freeze() in addition to Model.train()/eval(), and that ensures frozen layers remain frozen even when Model.train() is called on a parent, would find supporters? I think freezing parts of a model is a very common practice, and this forum is full of questions on how to do it properly. (And I believe many people miss some delicate subtleties, such as remembering to call eval() on these parts explicitly in the training loop.)
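(To sketch what I have in mind, without claiming this is how the API would have to look: a train() override that skips submodules marked with a hypothetical _frozen flag.)

import torch.nn as nn

class FreezableSequential(nn.Sequential):
    """Hypothetical sketch: train() leaves frozen submodules in eval mode."""

    def train(self, mode=True):
        self.training = mode
        for module in self.children():
            if getattr(module, "_frozen", False):
                module.eval()
            else:
                module.train(mode)
        return self

encoder = nn.BatchNorm1d(4)
encoder.eval()
encoder._frozen = True  # made-up flag, not an existing PyTorch attribute

model = FreezableSequential(encoder, nn.Linear(4, 1))
model.train()
assert not encoder.training  # stays frozen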