Hi,
How can I unfreeze only the BatchNorm2d layers and freeze all the other layers in a deep net like ResNet50?
Thank you.
You could first freeze all parameters and then unfreeze the parameters (weight and bias) of all batchnorm layers:
```python
import torch
import torch.nn as nn
from torchvision import models

# Freeze all parameters
model = models.resnet50()
for param in model.parameters():
    param.requires_grad_(False)

# Unfreeze bn params
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        # weight/bias exist but are None (not missing) when affine=False,
        # so check for None instead of using hasattr
        if module.weight is not None:
            module.weight.requires_grad_(True)
        if module.bias is not None:
            module.bias.requires_grad_(True)

# Check: only the bn params should have received gradients
out = model(torch.randn(1, 3, 224, 224))
out.mean().backward()
for name, param in model.named_parameters():
    if param.grad is not None:
        print(name, param.grad.abs().sum())
```