When you wish to freeze (i.e. not update) parts of the network, the recommended solution is to set requires_grad = False on those parameters, and/or (please confirm?) not pass the parameters you wish to freeze to the optimizer.
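For example, something like this (a rough sketch with a made-up two-layer model; the layer sizes and optimizer choice are just for illustration), is this the right recipe?

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 10), nn.Linear(10, 3))

# freeze the second layer's parameters
for p in model[1].parameters():
    p.requires_grad = False

# and/or only hand the still-trainable parameters to the optimizer
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1
)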
I would like to clarify that requires_grad = False simply avoids the unnecessary computation, update, and storage of gradients at those nodes, which saves memory, and does not split the graph into subgraphs. However, a parameter with requires_grad = False will still carry a grad_fn (if it has one), so that in the backward pass the gradient at that node is still technically calculated and passed backward, and the chain rule is still maintained and used for the parameters with requires_grad = True?
So if we have layers L1 -> L2 -> LOSS and want to freeze L2, the gradient of LOSS w.r.t. L1 requires the gradient of LOSS w.r.t. L2 as per the chain rule. My confusion is that, as the name implies, requires_grad = False on L2 will not compute that gradient, but we need it to update L1, no?
Please confirm. Thank you!
Hi @likethevegetable,
I tried a snippet as follows:
import torch
import torch.nn as nn

random_input = torch.randn(3, 3)
random_output = torch.randn(3, 3)
criterion = nn.MSELoss()

# model
l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)

intermediate = l1(random_input)
# intermediate.requires_grad = False  # raises a RuntimeError (see below)
# intermediate = intermediate.detach()
output = l2(intermediate)
loss = criterion(output, random_output)
loss.backward()
I think what confused you is the difference between the intermediate result and l1.weight: requires_grad=False is set on l1.weight, while what the chain rule needs is the intermediate result.
And if you try to set requires_grad=False on intermediate, it will raise an error:
RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().
So if we want to change the requires_grad flag of an intermediate variable, we should use .detach(), and then no gradients will flow back through the related subgraph.
Using the snippet above with the .detach() line uncommented, l1.weight.grad stays None while l2.weight.grad is computed normally.
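And to connect this back to your original L1 -> L2 -> LOSS question, here is a quick check along the same lines (reusing the same shapes as above; only the frozen layer changes): if you freeze l2.weight instead, the gradient still flows through l2's operation back into l1.

import torch
import torch.nn as nn

random_input = torch.randn(3, 3)
random_output = torch.randn(3, 3)
criterion = nn.MSELoss()

l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)
l2.weight.requires_grad = False  # freeze only l2's weight

loss = criterion(l2(l1(random_input)), random_output)
loss.backward()

print(l1.weight.grad is None)  # False: the gradient still reaches l1
print(l2.weight.grad is None)  # True: no gradient is accumulated for the frozen weight

So freezing a layer's weight only skips accumulating a gradient into that weight; it does not block the gradient from passing through the layer to earlier parameters.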
Yes, that was my issue: I did not mentally separate the weights from the operators. Thank you for replying.