Is it necessary to set BatchNorm to eval mode when we fine-tune the last layer of a pretrained model?

Suppose that I want to use a pretrained ResNet18 for classification on a small dataset. I only want to train the last (linear) classification layer while freezing all previous layers (referred to below as “frozen” layers). This is a fairly common use case, and the official PyTorch tutorial says that it is sufficient to do:

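import torch.nn as nn
import torchvision

# newer torchvision releases use weights=torchvision.models.ResNet18_Weights.DEFAULT instead of pretrained=True
model_conv = torchvision.models.resnet18(pretrained=True)
# freeze all pretrained parameters
for param in model_conv.parameters():
    param.requires_grad = False

# replace the classifier; a newly constructed layer has requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

...

# training phase (in the tutorial, dataloaders is a dict keyed by phase)
model_conv.train()
for inputs, labels in dataloaders['train']:
    optimizer.zero_grad()
    outputs = model_conv(inputs)
    ...
    loss.backward()
    optimizer.step()
    ...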
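
To see what the tutorial's freezing step leaves trainable, here is a minimal sketch on a hypothetical toy model (`TinyNet`, not resnet18, so nothing has to be downloaded) that follows the same pattern: freeze every parameter, then replace `fc`. Only the new head's parameters require gradients and go to the optimizer; note that this says nothing about BatchNorm's running statistics, which are buffers rather than parameters.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet18: a conv/BN feature extractor plus a linear `fc` head."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))

model = TinyNet()
for param in model.parameters():
    param.requires_grad = False                      # freeze everything, as in the tutorial
model.fc = nn.Linear(model.fc.in_features, 2)        # new head is trainable by default

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.weight', 'fc.bias']

# Only the head's parameters are handed to the optimizer.
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
```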

The tutorial says that setting requires_grad to False for the “frozen” layers is sufficient; it does not call .eval() on the BatchNorm2d modules among the “frozen” layers of resnet18. Does this mean that, even though the convolutional layers and the BatchNorm2d affine parameters (weight and bias) of the “frozen” layers are not learning, the BatchNorm running statistics (running_mean and running_var) in those layers are still being updated, although we only want to train the last linear layer? Is this a mistake?
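The question can be checked directly on a single `nn.BatchNorm2d` layer. This minimal sketch (layer size and input distribution are arbitrary assumptions) freezes the affine parameters and compares `running_mean` before and after one forward pass, first in training mode and then in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
for p in bn.parameters():
    p.requires_grad = False                # freeze the affine weight and bias

x = torch.randn(16, 4, 8, 8) * 3 + 5       # inputs whose mean differs from the initial running_mean (0)

bn.train()
before = bn.running_mean.clone()
bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)
print(changed_in_train)  # True: buffers are updated despite requires_grad=False

bn.eval()
before = bn.running_mean.clone()
bn(x)
changed_in_eval = not torch.equal(before, bn.running_mean)
print(changed_in_eval)   # False: eval mode uses the stored statistics and leaves them alone
```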

According to other forum posts and the documentation, it seems that in addition to setting requires_grad to False for the “frozen” layers (convolutional and BatchNorm layers), we should also call .eval() on all BatchNorm layers if we only want to train the last linear layer, which contradicts the official PyTorch tutorial on fine-tuning:

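import torch

def set_bn_eval(module):
    # matches BatchNorm1d/2d/3d, which all subclass _BatchNorm
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.eval()

# model.train() switches BatchNorm back to training mode, so apply this after every model.train() call
model.train()
model.apply(set_bn_eval)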
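
One way to confirm that this helper behaves as intended is a single training step on a hypothetical toy model (an `nn.Sequential` stand-in, not resnet18). The sketch calls `model.train()` first and the helper second, then checks that BatchNorm's `running_mean` is unchanged while the linear head's weights did move.

```python
import torch
import torch.nn as nn

def set_bn_eval(module):
    # Same idea as the helper in the post: BatchNorm stops using and updating batch statistics.
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        module.eval()

torch.manual_seed(0)
# Hypothetical toy stand-in: conv -> BN -> ReLU -> pool -> linear head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)
for p in model.parameters():
    p.requires_grad = False        # freeze everything ...
for p in model[5].parameters():
    p.requires_grad = True         # ... except the linear head

optimizer = torch.optim.SGD(model[5].parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
bn = model[1]
mean_before = bn.running_mean.clone()
head_before = model[5].weight.detach().clone()

model.train()                      # sets every submodule, BN included, to training mode
model.apply(set_bn_eval)           # then switch only the BN layers back to eval mode
inputs = torch.randn(4, 3, 16, 16)
labels = torch.tensor([0, 1, 0, 1])
optimizer.zero_grad()
loss = criterion(model(inputs), labels)
loss.backward()
optimizer.step()

bn_frozen = torch.equal(mean_before, bn.running_mean)                  # running stats untouched
head_updated = not torch.equal(head_before, model[5].weight.detach())  # head weights moved
print(bn.training, bn_frozen, head_updated)  # False True True
```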

Do we indeed need to call .eval() on BatchNorm layers when we only train the last linear layer of a pretrained resnet18 model?

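Yes, the running stats would still be updated in training mode. Setting requires_grad to False only stops gradient updates of the parameters; running_mean and running_var are buffers that BatchNorm updates during the forward pass whenever the module is in training mode.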
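For reference, the update applied in training mode is running = (1 - momentum) * running + momentum * batch_statistic, with `momentum=0.1` by default, and the batch variance that goes into `running_var` uses the unbiased estimator. A minimal sketch that reproduces one update by hand for a freshly initialized layer (running_mean starts at 0, running_var at 1):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)             # momentum defaults to 0.1
x = torch.randn(8, 3, 5, 5)
bn.train()
bn(x)

# Per-channel statistics over the batch and spatial dimensions (N, H, W).
batch_mean = x.mean(dim=(0, 2, 3))
batch_var = x.var(dim=(0, 2, 3))   # torch.var is unbiased by default, matching running_var
m = bn.momentum
expected_mean = (1 - m) * torch.zeros(3) + m * batch_mean  # initial running_mean is 0
expected_var = (1 - m) * torch.ones(3) + m * batch_var     # initial running_var is 1

mean_ok = torch.allclose(bn.running_mean, expected_mean, atol=1e-5)
var_ok = torch.allclose(bn.running_var, expected_var, atol=1e-5)
print(mean_ok, var_ok)  # True True
```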

I don’t know if it’s a mistake; my guess is that it could help when fine-tuning the model, especially if your new input data has different statistics.
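If the new data's statistics really do differ, a middle ground (a separate technique, not something proposed in this thread) is to re-estimate the BatchNorm statistics once on the new data and then keep BatchNorm in eval mode while training the head. The helper `reestimate_bn_stats` below is hypothetical; it relies on the real `reset_running_stats()` method and on `momentum=None`, which makes PyTorch keep a cumulative average instead of an exponential one, and it restores the original momentum afterwards.

```python
import torch
import torch.nn as nn

def reestimate_bn_stats(model, batches):
    """Hypothetical helper: recompute BatchNorm running stats on `batches`; no parameter is changed."""
    bn_layers = [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    saved_momenta = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.reset_running_stats()   # running_mean=0, running_var=1, num_batches_tracked=0
        bn.momentum = None         # cumulative moving average over all batches seen
    was_training = model.training
    model.train()                  # BN only updates its running stats in training mode
    with torch.no_grad():          # no autograd graph, no gradients
        for x in batches:
            model(x)
    model.train(was_training)      # restore the previous mode
    for bn, momentum in zip(bn_layers, saved_momenta):
        bn.momentum = momentum
    return model

# Usage on a toy model with shifted inputs (arbitrary example data).
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4)).eval()
new_data = [torch.randn(8, 3, 10, 10) * 2 + 1 for _ in range(5)]
reestimate_bn_stats(model, new_data)
```

With equally sized batches, the cumulative average of the per-batch means equals the mean over all of the new data.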