Unfreeze BatchNorm only in a deep net

Hi,

How can I unfreeze only the BatchNorm2d layers and freeze all the other layers in a deep net like ResNet50?

Thank you.

You could first freeze all parameters and later unfreeze the parameters of all batchnorm layers:

import torch
import torch.nn as nn
from torchvision import models

# Freeze all parameters
model = models.resnet50()
for param in model.parameters():
    param.requires_grad_(False)

# Unfreeze bn params
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
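        # affine=False layers still have weight/bias attributes (set to None),
        # so check the value instead of using hasattr
        if module.weight is not None:
            module.weight.requires_grad_(True)
        if module.bias is not None:
            module.bias.requires_grad_(True)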

# Check: only the BatchNorm parameters should receive gradients
out = model(torch.randn(1, 3, 224, 224))
out.mean().backward()

for name, param in model.named_parameters():
    if param.grad is not None:
        print(name, param.grad.abs().sum())