I have a network that consists of batch normalization (BN) layers and other layers (convolution, FC, dropout, etc.).
I was wondering how we can do the following:

1. Freeze all the layers and train only the BN layers.
2. Freeze the BN layers and train every other layer in the network.

My main issue is how to handle freezing and training the BN layers.
You can do something like this, hope this works for you.
model = Net()

# Train only the BN layers, freeze everything else
for name, child in model.named_children():
    if name.find('BatchNorm') != -1:
        for param in child.parameters():
            param.requires_grad = True
    else:
        for param in child.parameters():
            param.requires_grad = False
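A common companion step (not part of the snippet above) is to hand the optimizer only the parameters that still require gradients; the optimizer and learning rate below are just placeholders:

import torch.optim as optim

# Pass only the still-trainable parameters to the optimizer
# (hypothetical optimizer choice and lr, adjust to your setup)
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()), lr=0.01
)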
The running stats are not parameters (they do not get gradients and the optimizer won’t update them), but buffers, which will be updated during the forward pass.
You could call .eval() on the batchnorm layers to apply the running stats instead of calculating and using the batch statistics and updating the running stats (which would be the case during .train()).
@Usama_Hasan’s code snippet could be used to freeze the affine parameters of all batchnorm layers, but I would rather use isinstance(child, nn.BatchNorm2d) instead of the name check.
Anyway, his code should also work.
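For reference, a minimal sketch of the isinstance-based variant for the first case (train only the BN layers, freeze everything else); the tuple of batchnorm classes is an assumption, adjust it to the layers your model actually uses:

import torch.nn as nn

# Freeze everything first ...
for param in model.parameters():
    param.requires_grad = False

# ... then unfreeze only the batchnorm affine parameters
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        for param in module.parameters():
            param.requires_grad = True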
Usually I use model.train() during training, but now I don't want to train the BN layers. Does it sound right if, during training, I call model.eval() and then do the following?
# Freeze the BN layers, train everything else
for name, child in model.named_children():
    if name.find('BatchNorm') != -1:
        for param in child.parameters():
            param.requires_grad = False
    else:
        for param in child.parameters():
            param.requires_grad = True
The code looks alright, but again I would recommend using isinstance instead of the name check.
Note that model.eval() will also disable dropout layers (and maybe change the behavior of custom layers), so you might want to call eval() only on the batchnorm layers.
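A small sketch of that idea (a hypothetical helper; keep the model in train() so dropout stays active and switch only the batchnorm layers to eval()):

import torch.nn as nn

def set_bn_eval(module):
    # Only batchnorm layers use their running stats and stop updating them
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.eval()

model.train()             # dropout etc. stays in training mode
model.apply(set_bn_eval)  # batchnorm layers switched to eval mode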
You don't need to use both hasattr and eval() together, but I just did it to be safe.
for module in model.modules():
    # print(module)
    if isinstance(module, nn.BatchNorm2d):
        if hasattr(module, 'weight'):
            module.weight.requires_grad_(False)
        if hasattr(module, 'bias'):
            module.bias.requires_grad_(False)
        module.eval()
@ptrblck I use module.eval() to freeze the stats of an instancenorm layer (with track_running_stats=True) after some epochs. However, the running_mean and running_var are still being updated during training. Any idea why this might happen?
I use version 1.6 and the issue arises when I switch to eval() during training. I have a situation like the following:
from torch.nn import InstanceNorm1d

def _freeze_norm_stats(net):
    try:
        for m in net.modules():
            if isinstance(m, InstanceNorm1d):
                m.track_running_stats = False
                m.eval()
    except ValueError:
        print("error with instancenorm")
        return
model = Model().cuda()

for epoch in range(start_epoch, end_epoch):
    # for epochs < 10, the model updates the running_mean and running_var
    # of its instancenorm layer
    if epoch == 10:
        print("freeze stats:")
        model.apply(_freeze_norm_stats)
    # for epochs >= 10, the model should continue with fixed mean and var
    # for its instancenorm layer
Unfortunately, printing the running_mean and running_var after calling _freeze_norm_stats shows that they are still being updated over training steps. Maybe I have something wrong in the function implementation, but I am still not able to figure it out.
What could be the easiest way to freeze the batchnorm layers in, say, layer4 of ResNet34? I am fine-tuning only layer4, so I plan to check both with and without freezing the BN layers.
I checked resnet34.layer4.named_children() and can write loops to fetch the BN layers inside layer4, but wanted to check whether there is a more elegant way.
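One option (a sketch, assuming torchvision's resnet34 and that only the BN layers inside layer4 should be frozen) is to iterate over layer4.modules() and filter with isinstance:

import torch.nn as nn
from torchvision import models

resnet34 = models.resnet34(pretrained=True)

for module in resnet34.layer4.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()                      # use the running stats
        for param in module.parameters():  # freeze the affine weight and bias
            param.requires_grad_(False)

Keep in mind that a later model.train() call switches these layers back to training mode, so the eval() part would have to be re-applied after every model.train().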
I encountered the same issue: the running_mean/running_var of a batchnorm layer are still being updated even after calling bn.eval(). It turns out that the only way to freeze the running_mean/running_var is bn.track_running_stats = False. I tried three settings:
I tried your way and it absolutely works when setting track_running_stats=False. However, just out of curiosity: the official PyTorch documentation says that if track_running_stats is set to False, running_mean and running_var will be None and batch norm will always use batch statistics.
My guess is that setting track_running_stats = False afterwards does not reset running_mean and running_var, so all it does is stop updating them. Please correct me if my assumption is wrong.
Yes, using track_running_stats=False will always normalize the input activation with the current batch stats, so you would have to make sure that enough samples are passed to the model (even during testing).
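A quick sketch to check this behaviour (a minimal example I put together, assuming the layer was originally constructed with track_running_stats=True):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)   # track_running_stats=True by default
bn.train()

bn(torch.randn(8, 3))             # this forward pass updates the running stats
before = bn.running_mean.clone()

bn.track_running_stats = False    # stop tracking; batch stats are still used in train mode
bn(torch.randn(8, 3))

print(torch.allclose(before, bn.running_mean))  # True: running stats no longer updated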