I want to train the network with all 5 layers, as usual, for about 10 epochs. After that, I want to train just the last 3 layers (layer 3, layer 4 and output). How can I tell PyTorch to stop updating the weights of input and layer 1? Could someone provide a very simple, but working, example?
Thanks for the reply. Should I do this for every layer? For example, if the above architecture is part of a model class:
for epoch in range(number_of_epochs):
    if epoch < 10:
        # First 10 epochs: input and layer1 are trainable
        for p in model.input.parameters():
            p.requires_grad = True
        for q in model.layer1.parameters():
            q.requires_grad = True
    else:
        # Afterwards: freeze input and layer1
        for p in model.input.parameters():
            p.requires_grad = False
        for q in model.layer1.parameters():
            q.requires_grad = False
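You don't need to reset the flags on every epoch. Parameters are trainable by default, so you just need to set requires_grad = False for both the input and layer1 layers once, when epoch reaches 10 (with zero-based counting, that is after 10 full epochs of training all layers).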
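# Make sure everything is trainable to start with (this is the default)
for p in model.parameters():
    p.requires_grad = True

for epoch in range(number_of_epochs):
    if epoch == 10:
        # Freeze input and layer1 once; the flag persists for later epochs
        for p in model.input.parameters():
            p.requires_grad = False
        for q in model.layer1.parameters():
            q.requires_grad = False
    # ... the usual training step for this epoch goes here ...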
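Since the question asked for a simple but working example, here is a minimal sketch of the whole thing end to end. The Net class, the layer sizes, the random data, the learning rate and the flat helper are all made up for illustration; only the attribute names input and layer1 come from the thread. It uses plain SGD and calls zero_grad(set_to_none=True), so the frozen parameters have no gradient at all and the optimizer skips them. With momentum or weight decay, a leftover or zeroed gradient tensor could still move the frozen weights, which is why clearing to None matters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class Net(nn.Module):
    """Hypothetical 5-layer network; sizes are arbitrary."""

    def __init__(self):
        super().__init__()
        self.input = nn.Linear(4, 8)
        self.layer1 = nn.Linear(8, 8)
        self.layer3 = nn.Linear(8, 8)
        self.layer4 = nn.Linear(8, 8)
        self.output = nn.Linear(8, 1)

    def forward(self, x):
        for layer in (self.input, self.layer1, self.layer3, self.layer4):
            x = torch.relu(layer(x))
        return self.output(x)


def flat(module):
    # Detached copy of all of a module's parameters as one vector.
    return torch.cat([p.detach().flatten() for p in module.parameters()])


model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x, y = torch.randn(32, 4), torch.randn(32, 1)  # toy data, assumption

number_of_epochs = 20
for epoch in range(number_of_epochs):
    if epoch == 10:
        # Freeze input and layer1 once, after 10 epochs of full training.
        for p in model.input.parameters():
            p.requires_grad = False
        for q in model.layer1.parameters():
            q.requires_grad = False
        # Snapshots taken at freeze time, only to check the result below.
        frozen_before = torch.cat([flat(model.input), flat(model.layer1)])
        output_before = flat(model.output)

    # set_to_none=True drops old gradients so frozen params are skipped.
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

frozen_after = torch.cat([flat(model.input), flat(model.layer1)])
frozen_unchanged = torch.equal(frozen_before, frozen_after)
output_changed = not torch.equal(output_before, flat(model.output))
print("frozen layers unchanged:", frozen_unchanged)
print("output layer still training:", output_changed)
```

The two snapshots are only there to confirm the behaviour: input and layer1 should be bit-for-bit identical after the freeze, while the output layer keeps changing. In real training code you would drop the snapshot lines and keep just the epoch == 10 block.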