Update only a middle layer of a neural network

Hi,

Say I have a network, like the figure below.
During training, I want to freeze layers 1 and 3 and only update the weights of layer 2. Is that possible?

I don’t know how to do this because, according to the Autograd documentation, if at least one of the inputs to a node in the computation requires grad, then that node requires grad in the graph (which makes sense!).
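A minimal sketch of the rule quoted above, using plain tensors rather than a model (the names `x`, `w`, `y`, `z` are illustrative). The rule only says whether a *result* requires grad. It does not force every leaf tensor feeding into it to receive a gradient.

```python
import torch

x = torch.randn(3)                      # requires_grad=False by default
w = torch.randn(3, requires_grad=True)  # this leaf asks for gradients

y = x * w    # at least one input requires grad, so the result does too
z = x * 2.0  # no input requires grad, so the result does not

print(y.requires_grad)  # True
print(z.requires_grad)  # False
```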

But I need to update only the middle layer of the network and leave the rest untouched.

[Figure: diagram of the network with layers 1, 2 and 3]

Thanks.


Hi,

You can go through every parameter that you don’t want to update and set it not to require gradients:

for layer in (model.layer1, model.layer3):
    for p in layer.parameters():
        p.requires_grad = False

This means that no gradients will be computed for these parameters.
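To check this end to end, here is a small self-contained sketch. The `Net` class, its layer sizes and the attribute names `layer1`/`layer2`/`layer3` are assumptions chosen to match `model.layer2` used in this answer. After one backward pass, only layer2’s parameters end up with a `.grad`; the frozen ones stay `None`.

```python
import torch
import torch.nn as nn

# Hypothetical three-layer model; names and sizes are assumptions.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 8)
        self.layer2 = nn.Linear(8, 8)
        self.layer3 = nn.Linear(8, 2)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return self.layer3(x)

torch.manual_seed(0)
model = Net()

# Freeze layers 1 and 3: no gradients will be computed for their parameters.
for layer in (model.layer1, model.layer3):
    for p in layer.parameters():
        p.requires_grad = False

loss = model(torch.randn(5, 4)).sum()
loss.backward()

# Expected: only layer2.weight and layer2.bias print True.
for name, p in model.named_parameters():
    print(name, p.grad is not None)
```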
As a further optimization, you can pass only the parameters of the layer you want to optimize to your optimizer, so that it doesn’t have to check all the other parameters, which won’t have gradients:

my_opt = torch.optim.SGD(model.layer2.parameters(), lr=0.01)  # lr is an example value; add other arguments as needed
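A sketch of one training step with that optimizer, assuming a hypothetical model built with `nn.Sequential` and named children so that `model.layer2` works; the sizes, the loss and `lr=0.1` are arbitrary example choices. Snapshotting the weights before `step()` shows that only layer2 changes.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical model with named children so that model.layer2 exists.
model = nn.Sequential(OrderedDict([
    ("layer1", nn.Linear(4, 8)),
    ("act1", nn.ReLU()),
    ("layer2", nn.Linear(8, 8)),
    ("act2", nn.ReLU()),
    ("layer3", nn.Linear(8, 2)),
]))

# Freeze layers 1 and 3.
for layer in (model.layer1, model.layer3):
    for p in layer.parameters():
        p.requires_grad = False

# Only layer2's parameters are handed to the optimizer.
my_opt = torch.optim.SGD(model.layer2.parameters(), lr=0.1)

# Snapshot every parameter so we can see which ones the step modifies.
before = {name: p.detach().clone() for name, p in model.named_parameters()}

loss = model(torch.randn(16, 4)).pow(2).mean()
my_opt.zero_grad()
loss.backward()
my_opt.step()

changed = {name: not torch.equal(before[name], p) for name, p in model.named_parameters()}
print(changed)
```

If you would rather not name the layer explicitly, a common alternative is to pass `filter(lambda p: p.requires_grad, model.parameters())` to the optimizer, which picks up whatever is left unfrozen.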

Thanks for your answer.
But to compute layer2’s gradients, I think we need to compute gradients through layer3. If I set requires_grad to False for layer3’s parameters, how would that be possible?

You only set requires_grad=False on layer3’s parameters, which means no gradients will be computed for them. But the input to layer3, which comes from layer2, will have requires_grad=True, because gradients with respect to it are needed to reach layer2’s parameters. So the output of layer3 will require gradients, and during the backward pass only the gradients with respect to layer3’s input will be computed, not with respect to its parameters.
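A minimal sketch of this, with two standalone `nn.Linear` modules as hypothetical stand-ins for layer2 (trainable) and layer3 (frozen). The output of the frozen layer still requires grad because its input does, and backward passes through it to reach layer2’s weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for layer2 (trainable) and layer3 (frozen).
layer2 = nn.Linear(8, 8)
layer3 = nn.Linear(8, 2)
for p in layer3.parameters():
    p.requires_grad = False

x = torch.randn(5, 8)  # input that does not require grad
h = layer2(x)          # requires grad because layer2's parameters do
out = layer3(h)        # still requires grad because its input h does

print(h.requires_grad, out.requires_grad)  # True True

# Backward flows through layer3 (grad w.r.t. its input h) into layer2.
out.sum().backward()
print(layer3.weight.grad)              # None: nothing computed for layer3's parameters
print(layer2.weight.grad is not None)  # True
```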

Similarly for layer2: its input (the output of the frozen layer1) does not require gradients, but its parameters do, so its output does as well. During the backward pass, though, only the gradients with respect to its parameters will be computed, not with respect to its input.
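The layer2 side can be sketched the same way, again with hypothetical standalone modules: a frozen layer1 feeding a trainable layer2. Because nothing going into layer1 requires grad, autograd records no graph for it, so the backward pass stops at layer2’s input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins: frozen layer1 feeding trainable layer2.
layer1 = nn.Linear(4, 8)
layer2 = nn.Linear(8, 8)
for p in layer1.parameters():
    p.requires_grad = False

x = torch.randn(5, 4)
h1 = layer1(x)   # nothing here requires grad, so no graph is recorded
h2 = layer2(h1)  # layer2's parameters require grad, so h2 does

print(h1.requires_grad, h1.grad_fn)  # False None
print(h2.requires_grad)              # True

h2.sum().backward()
print(layer2.weight.grad is not None)  # True: grads w.r.t. layer2's parameters
print(layer1.weight.grad)              # None: backward stops at h1
```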


Another thing regarding this question:
What if I have batchnorm in layer3?
Should I manually call .eval() for layer3?

.eval() has nothing to do with gradient computation; it changes how layers behave. In eval mode, BatchNorm uses its saved running statistics instead of the current batch statistics, dropout becomes the identity, etc.
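One consequence worth checking: BatchNorm’s running statistics are buffers, not parameters, so setting requires_grad=False does not stop them from being updated in train mode. A sketch, assuming a hypothetical layer3 made of `nn.Linear` followed by `nn.BatchNorm1d`, shows that only eval mode leaves them untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical "layer3" containing BatchNorm, with all its parameters frozen.
layer3 = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
for p in layer3.parameters():
    p.requires_grad = False

bn = layer3[1]
x = torch.randn(16, 8)

# Train mode: running_mean is still updated from the batch statistics.
layer3.train()
before = bn.running_mean.clone()
layer3(x)
updated_in_train = not torch.equal(before, bn.running_mean)

# Eval mode: the saved running statistics are used and left unchanged.
layer3.eval()
before = bn.running_mean.clone()
layer3(x)
updated_in_eval = not torch.equal(before, bn.running_mean)

print(updated_in_train, updated_in_eval)  # True False
```

Note that calling `model.train()` on the parent model switches every submodule back to train mode, so if you want layer3’s statistics fixed, `layer3.eval()` would need to be called again after each `model.train()`.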