How to freeze part of the model?

import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.network1 = nn.Sequential(
            nn.Linear(10,100),
            nn.Dropout(0.2),
            nn.ReLU(),
        )
        self.network2 = nn.Sequential(
            nn.Linear(100,100),
            nn.Dropout(0.2),
            nn.ReLU(),
        )        

    def forward(self, x):
        x = self.network1(x)
        x = self.network2(x)
        return x

I want to freeze network2 in Network(), but I don't know how. Let me guess:
First, train the whole model.
Second, freeze network2 and fine-tune network1.
Q1) Is this flow right?

Q2) How can I freeze network2?
(If the above flow is right) After training the whole network, do I just change requires_grad from True to False? Is that all I have to do?

# Freeze all parameters in network2 (assuming `model = Network()`):
for p in model.network2.parameters():
    p.requires_grad = False

When you set requires_grad=False, autograd no longer computes gradients for those parameters, so they won't be updated during the backward pass.
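To illustrate (a small self-contained sketch, not part of the original posts): after freezing a parameter, backward() simply never populates its .grad.

import torch
import torch.nn as nn

lin = nn.Linear(4, 2)
lin.weight.requires_grad = False  # freeze the weight

out = lin(torch.randn(1, 4)).sum()
out.backward()
print(lin.weight.grad)  # None: no gradient was computed for the frozen weight
print(lin.bias.grad)    # a tensor: the bias is still trainable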

You can easily freeze all the network2 parameters via:

def freeze_network2(model):
    # Disable gradient computation for every parameter whose
    # name shows it belongs to the network2 submodule.
    for name, p in model.named_parameters():
        if "network2" in name:
            p.requires_grad = False
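As a quick sanity check (a minimal sketch assuming the Network class defined above), you can print which parameters remain trainable after calling freeze_network2:

model = Network()
freeze_network2(model)
for name, p in model.named_parameters():
    print(name, p.requires_grad)
# network1.0.weight True,  network1.0.bias True
# network2.0.weight False, network2.0.bias False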

From your description I'm not entirely sure why you would move to fine-tuning (usually you do this when you train the network on a large dataset and then transfer to a smaller target dataset).
But more to your question: I would recommend creating a new optimizer (or using two from the start), because many optimizers have a momentum term that can keep changing parameters even when their gradients are zero.
Other than that, your general approach should be OK.
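For example (a minimal sketch assuming the Network class from the question; the lr and momentum values are placeholders), freezing network2 and then building a fresh optimizer over only the still-trainable parameters avoids stale momentum buffers:

import torch

model = Network()
# ... train the whole model first with an optimizer over model.parameters() ...

# Freeze network2, then create a NEW optimizer that only sees network1,
# so leftover momentum can't keep nudging the frozen weights.
for p in model.network2.parameters():
    p.requires_grad = False

optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad),
    lr=0.01, momentum=0.9,  # placeholder hyperparameters
)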

Best regards

Thomas


Thanks for the reply!

# nn.Sequential has no .weight attribute itself; index into its layers:
model.network2[0].weight.requires_grad = False
model.network2[0].bias.requires_grad = False

Are lines like these not needed?

That depends on what you need.

If you want to keep some of the parameters fixed during training,
you can simply leave those parameters out when passing parameters to the optimizer:

# Collect only the parameters that should stay trainable,
# skipping everything that belongs to network2.
trainable_parameters = []
for name, p in model.named_parameters():
    if "network2" not in name:
        trainable_parameters.append(p)

# The optimizer never sees network2's parameters, so they stay fixed.
optimizer = torch.optim.SGD(params=trainable_parameters, lr=0.1, momentum=1e-5)
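If the frozen parameters already have requires_grad=False, a common equivalent idiom (an alternative sketch, not from the post above) is to filter on that flag instead of on the name:

optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.1, momentum=1e-5,
)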