Understanding requires_grad = False

When you wish to freeze (not update) parts of the network, the recommended solution is to set requires_grad = False and/or (please confirm?) not pass the parameters you wish to freeze to the optimizer.
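A minimal sketch of both approaches combined. It assumes a toy two-layer nn.Sequential standing in for L1 -> L2. The model, data, and learning rate are illustrative and are not from the thread. Filtering the optimizer's parameter list is mostly hygiene: built-in optimizers such as SGD skip parameters whose .grad is None, so setting requires_grad = False alone already stops those parameters from being updated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer model mirroring L1 -> L2 -> LOSS from the question.
model = nn.Sequential(nn.Linear(3, 10), nn.Linear(10, 3))

# Freeze the second layer: autograd will not populate .grad for these parameters.
for p in model[1].parameters():
    p.requires_grad_(False)

# Pass only the trainable parameters to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

frozen_before = model[1].weight.clone()
trained_before = model[0].weight.clone()

x, target = torch.randn(4, 3), torch.randn(4, 3)
loss = nn.functional.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(model[1].weight, frozen_before))   # frozen layer is untouched
print(torch.equal(model[0].weight, trained_before))  # first layer was updated
```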

I would like to confirm that requires_grad = False simply avoids the unnecessary computation, updating, and storage of gradients at those nodes, and does not create subgraphs, which saves memory.
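The memory point can be checked with a single hypothetical nn.Linear. Autograd records an operation only when at least one of its inputs requires grad. A frozen layer fed a plain input therefore builds no graph at all. The same frozen layer fed a tracked input still records the operation, so gradients can pass through to that input.

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 10)
layer.requires_grad_(False)  # freeze every parameter of this layer

x = torch.randn(2, 3)  # plain input, requires_grad=False by default
out = layer(x)
print(out.requires_grad, out.grad_fn)  # False None: nothing recorded for backward

x_tracked = torch.randn(2, 3, requires_grad=True)
out_tracked = layer(x_tracked)
# The op is recorded because its input requires grad, even though the weights are frozen.
print(out_tracked.requires_grad, out_tracked.grad_fn is not None)  # True True
```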

However, will the parameters with requires_grad = False still contain a grad_fn (if they have one), so that in the backward pass the gradient through that node is technically still calculated and passed backward, and the chain rule is still maintained and used for the parameters with requires_grad = True?
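One nuance is worth checking here. A parameter is a leaf tensor, so it has no grad_fn whether or not it is frozen. The backward function belongs to the output of the operation that uses the parameter. This sketch uses two bias-free nn.Linear layers as stand-ins for L1 and L2 and assumes L2 is the frozen one.

```python
import torch
import torch.nn as nn

l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)
l2.weight.requires_grad_(False)  # freeze L2

# Parameters are leaf tensors: they never have a grad_fn, frozen or not.
print(l2.weight.is_leaf, l2.weight.grad_fn)  # True None

h = l1(torch.randn(4, 3))  # h depends on l1.weight, which requires grad
y = l2(h)
# The operation applied by the frozen layer is still recorded, because h requires grad.
print(y.grad_fn is not None)  # True
```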

So if we have layers L1 -> L2 -> LOSS and want to freeze L2, the gradient of LOSS w.r.t. L1 requires the gradient of LOSS w.r.t. L2, as per the chain rule. My confusion is that, as the name implies, requires_grad = False on L2 will not compute that gradient, but we need it to update L1, no?
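This can be checked directly. The following is a minimal sketch that assumes bias-free nn.Linear layers and random data, neither of which comes from the question. Freezing L2's weight stops autograd from computing the gradient with respect to that weight. The gradient with respect to L2's input, which is what L1 needs, is still propagated.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)
l2.weight.requires_grad_(False)  # freeze L2 only

x, target = torch.randn(4, 3), torch.randn(4, 3)
loss = nn.functional.mse_loss(l2(l1(x)), target)
loss.backward()

print(l2.weight.grad)        # None: no gradient stored for the frozen weight
print(l1.weight.grad.shape)  # torch.Size([10, 3]): L1 still receives its gradient
```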

Please confirm. Thank you!


Hi, @likethevegetable

I tried a snippet as follows:

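import torch
import torch.nn as nn

random_input = torch.randn(3, 3)
random_output = torch.randn(3, 3)
criterion = nn.MSELoss()
# model
l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)
intermediate = l1(random_input)
# intermediate.requires_grad = False  # RuntimeError: intermediate is not a leaf tensor
intermediate = intermediate.detach()  # cut the graph here: no gradient flows back into l1
output = l2(intermediate)

loss = criterion(output, random_output)
loss.backward()
print(l1.weight.grad)  # None
print(l2.weight.grad)  # 3x10 gradient tensor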

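I think what confused you is the difference between the intermediate result and the layer weights. requires_grad=False is set on a weight such as l2.weight, which is a leaf tensor, whereas what the chain rule needs is the gradient with respect to the intermediate result, and that gradient is still computed.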
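To make the distinction concrete, take h = l1(x) and y = l2(h). The chain rule gives dL/dW1 = (dL/dh)^T x, with dL/dh = (dL/dy) W2. It needs the value of W2 and the gradient with respect to the intermediate h, but never the gradient with respect to W2. The sketch below recomputes dL/dW1 by hand and compares it with autograd. It assumes random data and MSE loss with its default mean reduction.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)
l2.weight.requires_grad_(False)  # frozen weight

x, target = torch.randn(4, 3), torch.randn(4, 3)
h = l1(x)  # intermediate result: h = x @ W1^T
y = l2(h)  # output: y = h @ W2^T
loss = nn.functional.mse_loss(y, target)
loss.backward()

with torch.no_grad():
    # Chain rule by hand. mse_loss averages over all elements, hence the division by numel.
    dL_dy = 2 * (y - target) / y.numel()
    dL_dh = dL_dy @ l2.weight  # uses the *value* of W2, not a gradient of W2
    dL_dW1 = dL_dh.T @ x       # shape (10, 3), same as l1.weight

print(torch.allclose(dL_dW1, l1.weight.grad, atol=1e-6))  # True
```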

And if you try to set requires_grad=False on intermediate, it raises an error:

 RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

So if we want to stop gradients at an intermediate variable, we should use .detach(). The subgraph before it (here, l1) will then get no gradients flowing back.
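A short sketch of both behaviours on a hypothetical single layer. Assigning to requires_grad on a non-leaf tensor raises the RuntimeError quoted above. detach() instead returns a new leaf tensor that shares the same storage but is cut from the graph.

```python
import torch
import torch.nn as nn

l1 = nn.Linear(3, 10, bias=False)
intermediate = l1(torch.randn(3, 3))

raised = False
try:
    intermediate.requires_grad = False  # not allowed on a non-leaf tensor
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

detached = intermediate.detach()
# detach() returns a leaf that shares storage with intermediate but has no history.
print(detached.is_leaf, detached.requires_grad, detached.grad_fn)  # True False None
print(detached.data_ptr() == intermediate.data_ptr())  # True: no copy is made
```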

With the detach() line enabled and loss.backward() called, the snippet above shows l1.weight.grad = None, while l2.weight.grad is computed normally.


Yes, that was my issue: I did not mentally separate the weights from the operations.

Thank you for replying.