Train a network by updating only part of the weights

Let’s say we have a network made of 5 layers:

input = torch.nn.Linear(.., ..)
layer1 = torch.nn.Linear(.., ..)
layer2 = ...
layer3 = ...
output = torch.nn.Linear(.., 2)

I want to train the network with all 5 layers, as usual, for about 10 epochs, and then train just the last 3 layers (layer2, layer3 and output). How can I tell PyTorch to stop updating the weights of input and layer1? Could someone provide a very simple, but working, example?

for p in input.parameters():
    p.requires_grad=False
for p in layer1.parameters():
    p.requires_grad=False
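
To make this concrete, here is a minimal runnable sketch; the Net class, the layer sizes, the ReLU activations and the SGD optimizer are illustrative assumptions, not something prescribed in this thread:

import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.input = torch.nn.Linear(10, 20)
        self.layer1 = torch.nn.Linear(20, 20)
        self.layer2 = torch.nn.Linear(20, 20)
        self.layer3 = torch.nn.Linear(20, 20)
        self.output = torch.nn.Linear(20, 2)

    def forward(self, x):
        x = torch.relu(self.input(x))
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = torch.relu(self.layer3(x))
        return self.output(x)

model = Net()

# Freeze the first two layers: autograd will not compute gradients for their weights.
for p in model.input.parameters():
    p.requires_grad = False
for p in model.layer1.parameters():
    p.requires_grad = False

# Optionally, hand the optimizer only the parameters that are still trainable.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1
)

Filtering the parameters passed to the optimizer is not strictly required, but it keeps the frozen weights out of the optimizer's parameter groups altogether.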

Thanks for the reply. Should I do this for every layer? For example, if the above architecture is part of a model class:

for epoch in range(number_of_epochs):
    if epoch < 10:
        for p in model.input.parameters():
            p.requires_grad = True
        for q in model.layer1.parameters():
            q.requires_grad = True
    else:
        for p in model.input.parameters():
            p.requires_grad = False
        for q in model.layer1.parameters():
            q.requires_grad = False

Would this work?

You just need to set requires_grad=False for both the input and layer1 layers once, at the 10th epoch:

for p in model.parameters():
    p.requires_grad = True

for epoch in range(number_of_epochs):
    if epoch == 10:
        for p in model.input.parameters():
            p.requires_grad = False
        for q in model.layer1.parameters():
            q.requires_grad = False

This should work.
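
Put together as a full loop, the schedule could look like the sketch below. It reuses the hypothetical Net class from the earlier sketch, and the random data, MSE loss and learning rate are placeholders just to make it runnable:

model = Net()  # fresh model; every layer is trainable at first
x = torch.randn(32, 10)
y = torch.randn(32, 2)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
number_of_epochs = 20

for epoch in range(number_of_epochs):
    if epoch == 10:
        # From the 10th epoch on, stop computing gradients for the first two layers.
        for p in model.input.parameters():
            p.requires_grad = False
        for q in model.layer1.parameters():
            q.requires_grad = False

    optimizer.zero_grad(set_to_none=True)  # frozen parameters keep grad=None
    loss = criterion(model(x), y)
    loss.backward()                        # no gradients are computed for the frozen weights
    optimizer.step()                       # parameters whose grad is None are skipped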


The above method seems to have no effect whatsoever on my network. Are you sure this is the way to prevent a layer’s weights from updating?

@bixqu Yes, if you declare that your weights don't need gradients (with requires_grad=False), then no gradients will be computed with respect to those weights.
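
As a quick sanity check (again using the hypothetical Net class from the sketches above), a freshly frozen layer keeps .grad equal to None after backward(), and an optimizer step leaves its weights untouched:

model = Net()
for p in model.input.parameters():
    p.requires_grad = False

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
before = model.input.weight.clone()

loss = model(torch.randn(8, 10)).sum()
loss.backward()
optimizer.step()

print(model.input.weight.grad)                  # None: no gradient was computed
print(torch.equal(model.input.weight, before))  # True: the frozen weights did not change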