How to freeze part of a model?


(J Na) #1
import torch.nn as nn

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.network1 = nn.Sequential(
            nn.Linear(10,100),
            nn.Dropout(0.2),
            nn.ReLU(),
        )
        self.network2 = nn.Sequential(
            nn.Linear(100,100),
            nn.Dropout(0.2),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.network1(x)
        x = self.network2(x)
        return x

I want to freeze network2 in Network(), but I don’t know how.
My guess is:
First, train the whole model.
Second, freeze network2 and fine-tune network1.
Q1) Is this flow right?

Q2) How can I freeze network2?
(If the above flow is right) After training the whole network, do I just change requires_grad from True to False? Is that all I have to do?
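A minimal sketch of that two-stage flow, assuming random dummy data, an MSE loss and plain SGD (none of which come from the thread): phase 1 trains every parameter, phase 2 freezes network2 and builds a fresh optimizer over only the parameters that are still trainable.

```python
import torch
import torch.nn as nn

class Network(nn.Module):
    # Compact copy of the model from the post above.
    def __init__(self):
        super().__init__()
        self.network1 = nn.Sequential(nn.Linear(10, 100), nn.Dropout(0.2), nn.ReLU())
        self.network2 = nn.Sequential(nn.Linear(100, 100), nn.Dropout(0.2), nn.ReLU())

    def forward(self, x):
        return self.network2(self.network1(x))

torch.manual_seed(0)
model = Network()
# Hypothetical dummy data standing in for a real dataset.
x, y = torch.randn(32, 10), torch.randn(32, 100)
loss_fn = nn.MSELoss()

def train(model, optimizer, steps=20):
    model.train()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()

# Phase 1: train the whole model.
train(model, torch.optim.SGD(model.parameters(), lr=0.01))

# Phase 2: freeze network2, then fine-tune network1 with a new optimizer.
for p in model.network2.parameters():
    p.requires_grad = False
frozen_before = [p.detach().clone() for p in model.network2.parameters()]
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.01)
train(model, optimizer)

# network2 should be exactly as it was at the end of phase 1.
unchanged = all(torch.equal(a, b) for a, b in zip(frozen_before, model.network2.parameters()))
print(unchanged)
```

Gradients still flow backward through the frozen network2 into network1; only the frozen weights themselves stop changing.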

# model = Network(); freeze only the parameters of network2
for p in model.network2.parameters():
    p.requires_grad = False
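To see what requires_grad = False does in practice, here is a small check (the random input and the `.sum()` loss are placeholders, not from the thread): after one backward pass, the frozen network2 parameters have no `.grad`, while network1 still receives gradients through the frozen layers.

```python
import torch
import torch.nn as nn

class Network(nn.Module):
    # Compact copy of the model from the first post.
    def __init__(self):
        super().__init__()
        self.network1 = nn.Sequential(nn.Linear(10, 100), nn.Dropout(0.2), nn.ReLU())
        self.network2 = nn.Sequential(nn.Linear(100, 100), nn.Dropout(0.2), nn.ReLU())

    def forward(self, x):
        return self.network2(self.network1(x))

model = Network()
for p in model.network2.parameters():
    p.requires_grad = False

# Placeholder forward/backward pass on a random batch.
loss = model(torch.randn(4, 10)).sum()
loss.backward()

# Frozen parameters are skipped by autograd, so their .grad stays None.
for name, p in model.named_parameters():
    print(name, "requires_grad:", p.requires_grad, "grad is None:", p.grad is None)
```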

(KAI ZHAO) #2

When you set requires_grad=False, no gradients are computed for those parameters during the backward pass, so they won’t be updated.

You can easily freeze all of network2’s parameters like this:

def freeze_network2(model):
    for name, p in model.named_parameters():
        if "network2" in name:
            p.requires_grad = False
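Since network2 is an attribute of the model, a name-free alternative is `nn.Module.requires_grad_`, which sets the flag on every parameter of that submodule. The sketch below (again with a compact copy of the thread's `Network`) checks that it freezes exactly the same parameters as the name-based function.

```python
import torch.nn as nn

class Network(nn.Module):
    # Compact copy of the model from the first post.
    def __init__(self):
        super().__init__()
        self.network1 = nn.Sequential(nn.Linear(10, 100), nn.Dropout(0.2), nn.ReLU())
        self.network2 = nn.Sequential(nn.Linear(100, 100), nn.Dropout(0.2), nn.ReLU())

    def forward(self, x):
        return self.network2(self.network1(x))

def freeze_network2(model):
    # Name-based version from the reply above.
    for name, p in model.named_parameters():
        if "network2" in name:
            p.requires_grad = False

def frozen_names(model):
    # Sorted names of all parameters that will not receive gradients.
    return sorted(n for n, p in model.named_parameters() if not p.requires_grad)

a, b = Network(), Network()
freeze_network2(a)
b.network2.requires_grad_(False)  # module-level alternative
print(frozen_names(a))
print(frozen_names(b))
```

The attribute-based form also avoids accidental substring matches: a parameter whose name merely contains "network2" (say, one under a hypothetical `network20` submodule) would be frozen by the name check too.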

(Thomas V) #3

From your description I’m not entirely sure why you would move to fine-tuning (usually you do this when you train the network on a large dataset and then transfer it to a smaller target dataset).
But more to your question, I would recommend creating a new optimizer (or setting up two from the start), because many optimizers have a momentum term that can keep changing parameters even when their gradients are zero.
Other than that, your general approach should be OK.
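The momentum effect can be demonstrated on a single tensor. In this sketch (the values are arbitrary), SGD with momentum takes one ordinary step, then the gradient is set to an explicit zero tensor: the parameter still moves because the stored momentum buffer is reused. When `.grad` is `None`, the built-in SGD skips the parameter entirely, which is why a new optimizer that never sees the frozen parameters sidesteps the problem.

```python
import torch

p = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)

# Step 1: an ordinary update with a non-zero gradient.
p.grad = torch.tensor([1.0])
opt.step()
after_first = p.detach().clone()

# Step 2: the gradient is exactly zero, but the momentum buffer is not,
# so the parameter keeps moving.
p.grad = torch.zeros(1)
opt.step()
after_second = p.detach().clone()

# Step 3: with grad set to None, SGD leaves the parameter alone.
p.grad = None
opt.step()
after_third = p.detach().clone()

print(after_first.item(), after_second.item(), after_third.item())
```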

Best regards

Thomas


(J Na) #4

Thanks for the reply!

model.network2[0].weight.requires_grad = False
model.network2[0].bias.requires_grad = False

Are those lines not needed?
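A quick check of what the parameter loop already covers, using a compact copy of the thread's model: network2 is an `nn.Sequential`, so it has no `.weight` or `.bias` of its own, and its `parameters()` yields the weight and bias of the `nn.Linear` inside it (`Dropout` and `ReLU` have no parameters).

```python
import torch.nn as nn

class Network(nn.Module):
    # Compact copy of the model from the first post.
    def __init__(self):
        super().__init__()
        self.network1 = nn.Sequential(nn.Linear(10, 100), nn.Dropout(0.2), nn.ReLU())
        self.network2 = nn.Sequential(nn.Linear(100, 100), nn.Dropout(0.2), nn.ReLU())

    def forward(self, x):
        return self.network2(self.network1(x))

model = Network()
for p in model.network2.parameters():
    p.requires_grad = False

# The Sequential container itself has no weight attribute...
print(hasattr(model.network2, "weight"))
# ...but its nn.Linear does, and the loop above has already frozen it.
print(model.network2[0].weight.requires_grad, model.network2[0].bias.requires_grad)
print([name for name, _ in model.network2.named_parameters()])
```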


(KAI ZHAO) #5

It depends on what you need.

If you want to keep some of the parameters fixed during training,
you can simply leave those parameters out when you pass parameters to the optimizer:

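import torch

# model is an instance of Network from the first post
trainable_parameters = []
for name, p in model.named_parameters():
    # skip every parameter that belongs to network2
    if "network2" not in name:
        trainable_parameters.append(p)

# only the parameters passed here are updated by optimizer.step()
optimizer = torch.optim.SGD(params=trainable_parameters, lr=0.1, momentum=1e-5)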
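A short check of this approach, assuming a random batch and a placeholder `.sum()` loss: network2 keeps requires_grad=True and still receives gradients, but because it was never handed to the optimizer, `optimizer.step()` leaves it untouched.

```python
import torch
import torch.nn as nn

class Network(nn.Module):
    # Compact copy of the model from the first post.
    def __init__(self):
        super().__init__()
        self.network1 = nn.Sequential(nn.Linear(10, 100), nn.Dropout(0.2), nn.ReLU())
        self.network2 = nn.Sequential(nn.Linear(100, 100), nn.Dropout(0.2), nn.ReLU())

    def forward(self, x):
        return self.network2(self.network1(x))

torch.manual_seed(0)
model = Network()
# Only network1's parameters are given to the optimizer.
trainable_parameters = [p for name, p in model.named_parameters() if "network2" not in name]
optimizer = torch.optim.SGD(params=trainable_parameters, lr=0.1)

before = {n: p.detach().clone() for n, p in model.named_parameters()}
optimizer.zero_grad()
model(torch.randn(8, 10)).sum().backward()  # placeholder loss
optimizer.step()

# Which parameters were modified by the step?
changed = {n: not torch.equal(before[n], p) for n, p in model.named_parameters()}
print(changed)
```

The trade-off: gradients for network2 are still computed here (costing time and memory), which requires_grad=False would avoid, so the two techniques are often combined.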