Retrain BatchNorm layer only

I have a pretrained model such as ResNet50 and want to retrain it on new data.
In Keras I do the following; how can I do the same in PyTorch?
for layer in base_model.layers:
    if not isinstance(layer, layers.BatchNormalization):
        layer.trainable = False
How do I iterate over every layer in a PyTorch model to find the nn.BatchNorm2d layers?

You could filter out the batch norm layers and set the requires_grad attribute of all other parameters to False:

import torch
from torchvision import models

model = models.resnet18()

def freeze_all_but_bn(m):
    # Freeze weight and bias of every module that is not a BatchNorm layer
    if not isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        if hasattr(m, 'weight') and m.weight is not None:
            m.weight.requires_grad_(False)
        if hasattr(m, 'bias') and m.bias is not None:
            m.bias.requires_grad_(False)

# apply() calls the function recursively on every submodule and on the model itself
model.apply(freeze_all_but_bn)
print(model.fc.weight.requires_grad)
> False
print(model.layer4[1].bn2.weight.requires_grad)
> True

Checking against the _BatchNorm base class makes sure that no BatchNorm layer (BatchNorm1d, BatchNorm2d or BatchNorm3d) is accidentally frozen by the function that apply() calls recursively on every submodule.
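If you prefer an explicit loop closer to the Keras snippet, a minimal sketch is to freeze every parameter first and then re-enable gradients only for the BatchNorm layers found via model.modules(). The small nn.Sequential model below is just a stand-in for a pretrained network such as ResNet50, and the SGD settings are arbitrary placeholders.

```python
import torch
import torch.nn as nn

# Stand-in for a pretrained backbone (e.g. a torchvision ResNet)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# 1) Freeze every parameter in the model
for p in model.parameters():
    p.requires_grad_(False)

# 2) Unfreeze only BatchNorm parameters (BatchNorm1d/2d/3d share this base class)
for m in model.modules():
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        for p in m.parameters():
            p.requires_grad_(True)

# 3) Hand only the trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01, momentum=0.9)

print(len(trainable))  # 2: the BatchNorm2d weight and bias
```

Freezing all parameters first also covers parameters that are not named weight or bias. Note that requires_grad only affects the affine weight and bias; the BatchNorm running_mean and running_var buffers are still updated whenever the model is in train() mode.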


Hello,

Is there a way to use the apply method to train only the last 10 layers?
My current solution is to walk through the layers in reverse and set requires_grad=True, but I'd like to know if there is a cleaner way. The problem with my solution is that I need to know the block class names of each model, e.g. ResNet uses models.resnet.BasicBlock.

I don’t think there is a universal approach to get the “last 10 layers”, as it would depend on the model implementation and in particular the module creation.
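E.g. if you can make sure the layers are created in the same order in which they are used in forward, you could use a counter to select the desired modules. However, apply() visits submodules in the order they were registered in __init__, not the order in which they are called in forward, so if the creation order differs you can easily select the wrong layers, as seen here: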

import torch.nn as nn

# Print out_features of every nn.Linear in the order apply() visits it
def weight_init(m):
    if isinstance(m, nn.Linear):
        print(m.weight.shape[0])


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 2)
        self.fc3 = nn.Linear(2, 3)
        self.fc4 = nn.Linear(3, 4)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        return x

model = MyModel()
model.apply(weight_init)
> 1
  2
  3
  4

class MyModelUnord(nn.Module):
    def __init__(self):
        super(MyModelUnord, self).__init__()
        self.fc4 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(1, 2)
        self.fc1 = nn.Linear(1, 1)
        self.fc3 = nn.Linear(2, 3)
        
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        return x

model = MyModelUnord()
model.apply(weight_init)
> 4
  2
  1
  3

If you are writing custom modules, you could group layers into submodules, each defining a “unit” that can be frozen in a single operation. On the other hand, if you are working with model definitions created by other users, I would recommend a more manual approach in which you explicitly check that you are freezing the desired modules.
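A minimal sketch of the “unit” idea for a custom model: the hypothetical Classifier below groups its layers into features and head submodules, so nn.Module.requires_grad_ can freeze a whole unit in one call. The layer sizes and optimizer settings are arbitrary.

```python
import torch
import torch.nn as nn

class Classifier(nn.Module):
    """Hypothetical model split into two units that can be frozen independently."""
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(32, num_classes))

    def forward(self, x):
        return self.head(self.features(x))


model = Classifier()
model.features.requires_grad_(False)  # freeze the whole feature unit in one call

# Only the head's parameters are passed to the optimizer
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-3)
out = model(torch.randn(2, 3, 8, 8))
print(out.shape)  # torch.Size([2, 10])
```

For a model you did not write, the manual equivalent is to select parameters by name. For a torchvision ResNet, for example, looping over model.named_parameters() and calling p.requires_grad_(name.startswith(("layer4", "fc"))) would keep only the last stage and the classifier trainable.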