Should I use model.eval() when I freeze BatchNorm layers to finetune?

Hi,
I have a well-trained coarse net (including BN layers) which I want to freeze in order to finetune other layers added on top of it. I filtered out the parameters of the coarse net when constructing the optimizer. Is model.eval() or something else necessary in this case? I don’t want the BN layers to recalculate the mean and variance in every batch.

If you set your nn.BatchNorm layers to eval(), their running estimates won’t be updated anymore.
In addition to filtering out the parameters, you could also set their .requires_grad attribute to False, so that gradients won’t be computed when they aren’t needed.
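
A minimal sketch of both suggestions, assuming a small stand-in `coarse` network (an nn.Linear followed by nn.BatchNorm1d, not the poster’s real model) and a hypothetical trainable `head`. After eval() the running estimates stay fixed, and the frozen parameters get no .grad at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained "coarse" net.
coarse = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
bn = coarse[1]

# Stop gradient computation for every coarse parameter.
for p in coarse.parameters():
    p.requires_grad = False

# Eval mode: BatchNorm normalizes with its running stats and does not update them.
coarse.eval()
mean_before = bn.running_mean.clone()
var_before = bn.running_var.clone()

# Hypothetical trainable head on top of the frozen features.
head = nn.Linear(8, 1)
out = head(coarse(torch.randn(16, 4)))
out.sum().backward()

# Running estimates unchanged, frozen params have no grad, the head does.
print(torch.equal(bn.running_mean, mean_before), torch.equal(bn.running_var, var_before))
print(all(p.grad is None for p in coarse.parameters()), head.weight.grad is not None)
```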

Thanks!! But I am still a little confused.

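import torch

# Freeze every parameter, then unfreeze only the new `fine` submodule.
for p in model.parameters():
    p.requires_grad = False
for p in model.fine.parameters():
    p.requires_grad = True
# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)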

This is my code for finetuning, and I didn’t call model.eval(). Do the BN layers in my model behave the same as BN layers in a model that has been set to eval()?
If not, what should I do to freeze the BN layers (i.e. make them use the global running mean and variance instead of the statistics of each mini-batch)?
I am looking forward to your reply.

If you only want to fine-tune the parameters in model.fine you could do the following:

model = ...
optimizer = torch.optim.Adam(model.fine.parameters(), lr=0.001)

Now the optimizer will only try to update parameters within the .fine submodule.
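
A small sketch of this, assuming a hypothetical Net with `coarse` and `fine` submodules (the names come from the question, the architecture is invented for illustration). It checks that the optimizer leaves the coarse weights alone, and also shows that the BatchNorm running stats inside `coarse` still change in train mode, because they are buffers rather than parameters and the optimizer never touches them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class Net(nn.Module):
    # Hypothetical model: pretrained `coarse` part plus a new `fine` part.
    def __init__(self):
        super().__init__()
        self.coarse = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())
        self.fine = nn.Linear(8, 1)

    def forward(self, x):
        return self.fine(self.coarse(x))

model = Net()
# Only the parameters of model.fine are handed to the optimizer.
optimizer = torch.optim.Adam(model.fine.parameters(), lr=0.001)

coarse_w = model.coarse[0].weight.detach().clone()
fine_w = model.fine.weight.detach().clone()
running_mean = model.coarse[1].running_mean.clone()

model.train()  # the default mode, set explicitly here
loss = model(torch.randn(16, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

coarse_changed = not torch.equal(model.coarse[0].weight, coarse_w)
fine_changed = not torch.equal(model.fine.weight, fine_w)
# BN buffers are updated by the forward pass in train mode, optimizer or not.
stats_changed = not torch.equal(model.coarse[1].running_mean, running_mean)
print(coarse_changed, fine_changed, stats_changed)
```

Note that the coarse parameters still get gradients computed here, since their requires_grad is left at True; leaving them out of the optimizer only prevents them from being updated.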

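It depends on whether they were set to .eval() before, but the default mode is train() (e.g. after creating the model and loading its state_dict). Setting .requires_grad = False does not change this mode, so the BN layers would still normalize with the batch statistics and update their running estimates.
If you want to set the complete model to eval mode, just use model.eval().
Alternatively, if you only want to apply it to the batchnorm layers, you could use: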

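import torch

def set_bn_eval(module):
    # Put only the batchnorm layers (BatchNorm1d/2d/3d, ...) into eval mode.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.eval()

# apply() calls set_bn_eval recursively on every submodule.
model.apply(set_bn_eval)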
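
One thing worth checking with this approach, sketched here on a hypothetical Linear/BatchNorm/Dropout model: model.apply(set_bn_eval) leaves the other layers (such as Dropout) in training mode, but any later model.train() call switches the batchnorm layers back to train mode, so the helper has to be re-applied after it.

```python
import torch
import torch.nn as nn

def set_bn_eval(module):
    # Same helper as in the answer: switch only batchnorm layers to eval mode.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.eval()

# Hypothetical model mixing batchnorm with dropout.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Dropout(0.5), nn.Linear(8, 1))

model.train()
model.apply(set_bn_eval)
bn_training, dropout_training = model[1].training, model[2].training  # False, True

# A later model.train() (e.g. at the start of each epoch) resets batchnorm to train mode...
model.train()
bn_after_train = model[1].training  # True
# ...so the helper has to be applied again afterwards.
model.apply(set_bn_eval)
bn_reapplied = model[1].training  # False

print(bn_training, dropout_training, bn_after_train, bn_reapplied)
```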

Hi @ptrblck, a little more detail on this: does setting the batchnorm layers to eval() still allow us to train the gamma and beta parameters? I understand that the eval operation allows us to use the current batch’s mean and variance when fine-tuning a pretrained model.

Calling .train() or .eval() on batchnorm layers does not freeze the affine parameters, so the gamma (weight) and beta (bias) parameters can still be trained.

Calling eval() on batchnorm layers will use the running stats, while train() will use the batch stats and update the running stats.
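
To see the difference concretely, here is a sketch on a single nn.BatchNorm1d that compares the layer output with a manual normalization. In train() mode it matches the batch mean and biased batch variance; in eval() mode it matches running_mean and running_var; and the affine weight and bias still get gradients in eval() mode. The random inputs and warm-up passes are only there to give the running stats non-default values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3)

# Warm-up passes in train mode so the running stats move away from 0 / 1.
bn.train()
for _ in range(5):
    bn(torch.randn(8, 3) * 2 + 1)

# train(): normalize with the current batch stats (biased variance).
# This forward pass also updates the running stats once more.
y_train = bn(x)
batch_ref = (x - x.mean(0)) / torch.sqrt(x.var(0, unbiased=False) + bn.eps) * bn.weight + bn.bias

# eval(): normalize with the stored running_mean / running_var.
bn.eval()
y_eval = bn(x)
running_ref = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) * bn.weight + bn.bias

# gamma (weight) and beta (bias) still receive gradients in eval mode.
y_eval.sum().backward()

print(torch.allclose(y_train, batch_ref, atol=1e-5))
print(torch.allclose(y_eval, running_ref, atol=1e-5))
print(bn.weight.grad is not None, bn.bias.grad is not None)
```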

Thanks, that clarification was useful.

Hi @ptrblck, one follow-up on your earlier response: is setting requires_grad = False sufficient, even for batchnorm layers? Or do we have to do both, i.e. set requires_grad = False AND call BN_layer.eval()?

Setting the requires_grad attribute and calling train()/eval() on the layers do different things.
BatchNorm layers use trainable affine parameters by default, which are stored in the .weight and .bias attributes. These parameters have .requires_grad = True by default, and you can freeze them by setting this attribute to False.
During training (i.e. after calling model.train() or after creating the model) the batchnorm layer will normalize the input activation using the batch stats and will update the internal stats using a running average.
After calling model.eval() the batchnorm layers will use the trained internal running stats (stored as .running_mean and .running_var) to normalize the input activation.
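
A minimal sketch of how the two settings act independently, assuming a standalone nn.BatchNorm1d: freezing .weight and .bias via requires_grad does not stop the running stats from being updated in train mode; only eval() does.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# Freeze only the affine parameters (gamma = .weight, beta = .bias).
bn.weight.requires_grad = False
bn.bias.requires_grad = False

# Still in train mode: the running stats keep changing with every batch.
bn.train()
before = bn.running_mean.clone()
bn(torch.randn(16, 4) + 3.0)
updated_in_train = not torch.equal(bn.running_mean, before)

# After eval(), the stored running stats are used and left untouched.
bn.eval()
before = bn.running_mean.clone()
bn(torch.randn(16, 4) + 3.0)
updated_in_eval = not torch.equal(bn.running_mean, before)

print(updated_in_train, updated_in_eval)
```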

Got it, thanks so much for your detailed response! So if I only set requires_grad = False for the BN layers, they will still update their running averages during the training phase, which is not what I want. I should do both, so that the batchnorm layers use the stored .running_mean and .running_var values for normalization.

If you want to validate your model, wrapping the forward pass in with torch.no_grad() or with torch.inference_mode() and calling model.eval() would also work. You wouldn’t necessarily need to flip the .requires_grad attribute (that would also work, but these context managers might be a simpler way).
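
A sketch of such a validation pass, assuming a hypothetical small model and random stand-in validation batches. The outputs do not require grad under either context manager, and the batchnorm running stats are left unchanged because the model is in eval mode. torch.inference_mode is only available in newer PyTorch releases (1.9 and later).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical model; requires_grad is deliberately left at its default (True).
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
val_data = [torch.randn(16, 4) for _ in range(3)]  # stand-in validation batches

model.eval()  # batchnorm uses its running stats (and dropout, if any, is disabled)
stats_before = model[1].running_mean.clone()

with torch.no_grad():  # no autograd graph is recorded
    outs = [model(x) for x in val_data]

with torch.inference_mode():  # stricter alternative to no_grad
    outs_im = [model(x) for x in val_data]

print(all(not o.requires_grad for o in outs + outs_im))
print(torch.equal(model[1].running_mean, stats_before))

model.train()  # switch back before resuming training
```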

If I want to fine-tune a model with BN layers and freeze only the BN layers, I should call model.eval() and set the BN layers’ requires_grad to False. Am I right?

Yes: if you want to freeze the affine parameters (weight and bias), set their .requires_grad attribute to False. If you additionally want to use the running stats instead of updating them with the current batch stats, call .eval() on the batchnorm layers.
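
Putting both steps together, here is a sketch with a hypothetical freeze_bn helper (not a PyTorch API). It calls eval() only on the batchnorm layers, so other layers such as Dropout keep their training behavior, and it sets requires_grad = False on their affine parameters. Because model.train() puts the batchnorm layers back into train mode, the helper is called again after every model.train() in the toy training loop. Like the earlier set_bn_eval snippet, it relies on the private base class _BatchNorm.

```python
import torch
import torch.nn as nn

def freeze_bn(model):
    # Hypothetical helper: fully freeze every batchnorm layer in `model`.
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.eval()  # normalize with running stats and stop updating them
            if m.affine:
                m.weight.requires_grad = False  # freeze gamma
                m.bias.requires_grad = False    # freeze beta

torch.manual_seed(0)
# Toy model standing in for a pretrained network.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
bn, head = model[1], model[3]
bn_before = (bn.running_mean.clone(), bn.running_var.clone(), bn.weight.detach().clone())
head_before = head.weight.detach().clone()

# Freeze once before building the optimizer so the BN parameters are filtered out.
freeze_bn(model)
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)

for epoch in range(2):
    model.train()     # switches every layer, including batchnorm, to train mode
    freeze_bn(model)  # so the batchnorm layers are frozen again
    for _ in range(3):
        loss = model(torch.randn(16, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

bn_frozen = (torch.equal(bn.running_mean, bn_before[0])
             and torch.equal(bn.running_var, bn_before[1])
             and torch.equal(bn.weight, bn_before[2]))
head_updated = not torch.equal(head.weight, head_before)
print(bn_frozen, head_updated)
```

Calling freeze_bn before the optimizer is created matters here, since the parameter list is filtered on requires_grad at construction time.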