The running stats are not parameters (they do not get gradients and the optimizer won't update them), but buffers, which are updated during the forward pass.
You could call .eval() on the batchnorm layers to apply the running stats instead of calculating and using the batch statistics and updating the running stats (which would be the case during .train()).
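The parameter/buffer distinction can be seen directly on a batchnorm module:

```python
import torch.nn as nn

bn = nn.BatchNorm2d(3)

# Affine parameters: updated by the optimizer via gradients
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']

# Buffers: updated in-place during the forward pass in train mode
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```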
Thanks for your reply. Can you please give me a simple code snippet showing how I can train the network while freezing the BN layers?
@Usama_Hasan's code snippet could be used to freeze the affine parameters of all batchnorm layers, but I would rather use isinstance(child, nn.BatchNorm2d) instead of the name check. Anyway, his code should also work.
Usually I use model.train() when training, but now I don't want to train the BN layers during training. Does it sound right if during training I do model.eval() and then do the following?
```python
for name, child in model.named_children():
    if name.find('BatchNorm') != -1:
        for param in child.parameters():
            param.requires_grad = False
    else:
        for param in child.parameters():
            param.requires_grad = True
```
The code looks alright, but again I would recommend using isinstance instead of the name check. Note that model.eval() will also disable dropout layers (and maybe change the behavior of custom layers), so you might want to call eval() only on the batchnorm layers.
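A minimal sketch of that suggestion: keep the whole model in train mode, but switch only the batchnorm layers to eval mode so that dropout etc. stays active:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.BatchNorm1d(8),
    nn.Dropout(0.5),
)

model.train()  # dropout (and everything else) is in training mode
for module in model.modules():
    if isinstance(module, nn.BatchNorm1d):
        module.eval()  # this layer now uses its running stats

print(model[1].training, model[2].training)
# False True
```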
Oh I see, I didn't know I could call eval() just on the BN layers. I'll give it a try with eval() on the BN layers and will use isinstance. Thanks for clarifying.
Here is what you want to do. You don't need to use both hasattr and eval() together, but I did it just to be safe.
```python
for module in model.modules():
    # print(module)
    if isinstance(module, nn.BatchNorm2d):
        if hasattr(module, 'weight'):
            module.weight.requires_grad_(False)
        if hasattr(module, 'bias'):
            module.bias.requires_grad_(False)
        module.eval()
```
@ptrblck I use module.eval() to freeze the stats of an instancenorm layer (with track_running_stats=True) after some epochs. However, the running_mean and running_var are still updating over training. Any idea why this may happen?
I cannot reproduce this issue using this code snippet and the running stats are equal after calling norm.eval():
```python
norm = nn.InstanceNorm2d(num_features=3, track_running_stats=True)
print(norm.running_mean, norm.running_var)
# > tensor([0., 0., 0.]) tensor([1., 1., 1.])

x = torch.randn(2, 3, 24, 24)
out = norm(x)
print(norm.running_mean, norm.running_var)
# > tensor([-0.0029,  0.0005,  0.0003]) tensor([0.9988, 1.0021, 0.9980])

out = norm(x)
print(norm.running_mean, norm.running_var)
# > tensor([-0.0056,  0.0010,  0.0006]) tensor([0.9978, 1.0040, 0.9962])

norm.eval()
out = norm(x)
print(norm.running_mean, norm.running_var)
# > tensor([-0.0056,  0.0010,  0.0006]) tensor([0.9978, 1.0040, 0.9962])
```
Are you using an older PyTorch version, where this could have been a known issue?
I use version 1.6 and the issue arises when I switch to eval() during training. I have a situation like the following:
```python
def _freeze_norm_stats(net):
    try:
        for m in net.modules():
            if isinstance(m, InstanceNorm1d):
                m.track_running_stats = False
                m.eval()
    except ValueError:
        print("error with instancenorm")
        return

model = Model().cuda()
for epoch in range(start_epoch, end_epoch):
    # for epochs < 10, the model updates the running_mean and running_var
    # of its instancenorm layer
    if epoch == 10:
        print("freeze stats:")
        model.apply(_freeze_norm_stats)
    # for epochs >= 10, the model should continue with fixed mean and var
    # for its instancenorm layer
```
Unfortunately, printing the running_mean and running_var after calling _freeze_norm_stats shows that they're still being updated over training steps. Maybe I have something wrong in the function implementation, but I am still not able to figure this out.
Could you update to the latest stable release (1.7.1) and check if you are still facing this issue?
Removing the m.track_running_stats = False line and only keeping the .eval() call solved the problem.
What could be the easiest way to freeze the batchnorm layers in, say, layer4 of a ResNet34? I am finetuning only layer4, so I plan to check both with and without freezing the BN layers.
I checked resnet34.layer4.named_children() and can write loops to fetch the BN layers inside layer4, but want to check if there is a more elegant way.
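One slightly more compact option (a sketch; `resnet34.layer4` is assumed to be a torchvision-style submodule) is to call `.modules()` on the submodule directly, since it recurses through all nested children for you:

```python
import torch.nn as nn

def freeze_bn(module: nn.Module) -> None:
    """Put every batchnorm layer inside `module` into eval mode
    and exclude its affine parameters from gradient updates."""
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
            for p in m.parameters():
                p.requires_grad_(False)

# e.g. for a torchvision ResNet34:
# freeze_bn(resnet34.layer4)
```

Remember that a later model.train() call would flip the layers back to training mode, as discussed further down in this thread.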
Encountered the same issue: the running_mean/running_var of a batchnorm layer are still being updated even after bn.eval(). Turns out that the only way to freeze the running_mean/running_var is bn.track_running_stats = False. Tried 3 settings:
- bn.param.requires_grad = False & bn.eval()
- bn.param.requires_grad = False & bn.track_running_stats = False & bn.eval()
- bn.param.requires_grad = False & bn.track_running_stats = False
The first one did not work. Both the 2nd and 3rd worked. So my conclusion is to use bn.track_running_stats = False if you want to freeze running_mean/running_var. (Contradictory to How to freeze BN layers while training the rest of network (mean and var wont freeze) - #17 by Ahmed_m)
Additional info: torch.__version__ is 1.7.0. Seems there is a misalignment across versions.
I tried your way and it absolutely works with setting track_running_stats=False. However, just out of curiosity: the official PyTorch documentation says that if track_running_stats is set to False, then running_mean and running_var will be None and batch norm will always use the batch statistics?
My guess is that setting track_running_stats to False does not reset running_mean and running_var, and hence all it does is stop updating them. Please correct me if my assumption is wrong.
Yes, using track_running_stats=False will always normalize the input activation with the current batch stats, so you would have to make sure that enough samples are passed to the model (even during testing).
The documentation actually says:
track_running_stats: […] when set to False, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics, in both training and eval modes.
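This behavior can be checked directly (a small sketch): with track_running_stats=False the buffers are None, and even in eval mode the output is normalized with the current batch's statistics:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, track_running_stats=False)
print(bn.running_mean, bn.running_var)
# None None

bn.eval()
x = torch.randn(2, 3, 4, 4) * 5 + 10  # clearly not zero-mean / unit-variance
out = bn(x)

# The output is still normalized per channel with the batch stats,
# so each channel mean is roughly zero despite eval mode
print(out.mean(dim=(0, 2, 3)))
```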
Doesn't that mean that when a state dict is loaded which contains running_mean and running_var, these will be used instead of the current batch statistics?
Also: how are users to go about switching between train and eval mode for training and validation in each epoch, when model.train() sets all layers to training mode, even those that we initially froze by calling layer.eval()?
No, since it will directly fail:
```python
bn1 = nn.BatchNorm2d(3, track_running_stats=False)
bn2 = nn.BatchNorm2d(3)

print(bn1.running_mean)
# None
print(bn1.running_var)
# None

bn1.load_state_dict(bn2.state_dict())
# RuntimeError: Error(s) in loading state_dict for BatchNorm2d:
#   Unexpected key(s) in state_dict: "running_mean", "running_var", "num_batches_tracked".
```
I see! The load_state_dict documentation actually talks about this:
If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
So, don't mess with track_running_stats to freeze a BatchNorm layer, I guess.
But my last question persists: how do I sustainably freeze a BatchNorm layer so that it does not get accidentally unfrozen when the training loop switches between train and eval mode for training and validation?
My current idea would be to monkeypatch the train method:
```python
from copy import deepcopy

encoder = nn.BatchNorm1d(4)
assert (encoder.running_mean == 0).all()
assert (encoder.running_var == 1).all()

# Train encoder
x = torch.randn((8, 4))
encoder(x)

# Training changes the running stats
assert torch.norm(encoder.running_mean) != 0
assert torch.norm(encoder.running_var) != 0

encoder_state_dict_orig = deepcopy(encoder.state_dict())

# Freeze encoder
encoder.eval()
# Prevent encoder from ever becoming trainable again
encoder.train = lambda mode=True: encoder

classifier = nn.Linear(4, 1)
model = nn.Sequential(encoder, classifier)

# Training loop:
model.train()
assert not encoder.training
# train...
x = torch.randn((8, 4))
output = model(x)
# (+ the usual stuff...)

model.eval()
# validate...
assert not encoder.training
x = torch.randn((8, 4))
output = model(x)

# After freezing, the running stats did not change anymore,
# even though the model was trained
encoder_state_dict = encoder.state_dict()
assert all(
    (encoder_state_dict[k] == v).all() for k, v in encoder_state_dict_orig.items()
)
```
The original train method could be saved so that it can be restored later, if needed.