When you wish to freeze (not update) parts of the network, the recommended solution is to set
requires_grad = False on those parameters, and/or (please confirm?) to not pass the parameters you wish to freeze to the optimizer.
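For reference, here is a minimal sketch of both approaches side by side (the layer names `frozen` and `trainable` are hypothetical, just for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical two-stage model: freeze the first layer, train the second.
frozen = nn.Linear(4, 8, bias=False)
trainable = nn.Linear(8, 2, bias=False)

# Approach 1: turn off gradient computation for the frozen parameters.
for p in frozen.parameters():
    p.requires_grad = False

# Approach 2: only hand the still-trainable parameters to the optimizer.
model = nn.Sequential(frozen, trainable)
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1
)

# Only trainable.weight ends up registered with the optimizer.
num_opt_params = sum(len(g["params"]) for g in optimizer.param_groups)
```

Approach 1 alone stops gradients from being computed for those weights; approach 2 alone stops updates but gradients would still accumulate in .grad; combining both is the usual practice.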
I would like to clarify:
requires_grad = False simply avoids unnecessary computation, update, and storage of gradients at those nodes; it does not split the graph into separate subgraphs, which saves memory.
However, a node with
requires_grad = False will still carry a
grad_fn (if it has one), so that in the backward pass the gradient through that node is still computed and passed backward, and the chain rule is still maintained and used for the parameters with
requires_grad = True?
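One detail worth checking concretely: parameters are leaf tensors and never carry a grad_fn, frozen or not; it is the *output* of a frozen layer that keeps a grad_fn, as long as something upstream still requires grad. A small sketch:

```python
import torch
import torch.nn as nn

frozen = nn.Linear(3, 3, bias=False)
frozen.weight.requires_grad = False

# A leaf parameter never has a grad_fn, frozen or not.
param_grad_fn = frozen.weight.grad_fn  # None

# Case 1: the input requires grad, so the output records an op for backward.
x = torch.randn(2, 3, requires_grad=True)
out_tracked = frozen(x)

# Case 2: nothing in this graph requires grad, so no graph is built at all.
y = torch.randn(2, 3)
out_untracked = frozen(y)
```

Here out_tracked.grad_fn is not None even though the layer is frozen, while out_untracked.grad_fn is None.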
So if we have layers L1 -> L2 -> LOSS and want to freeze L2, the gradient of LOSS wrt L1 requires the gradient of LOSS wrt L2's output, as per the chain rule. My confusion is that, as the name implies,
requires_grad = False on L2 will not compute that gradient, but we need it to update L1, no?
Please confirm. Thank you.
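The chain-rule worry can be checked directly. In the following sketch (my own, not the poster's code), L2's weight is frozen, yet backward still propagates the gradient of the loss through L2's operation, which is exactly what L1 needs:

```python
import torch
import torch.nn as nn

l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)
l2.weight.requires_grad = False  # freeze L2

x = torch.randn(3, 3)
target = torch.randn(3, 3)
loss = nn.MSELoss()(l2(l1(x)), target)
loss.backward()

# L1 still receives a gradient: backward propagates d(loss)/d(intermediate)
# through L2's *operation* even though L2's *weight* gets no gradient.
l1_has_grad = l1.weight.grad is not None
l2_has_grad = l2.weight.grad is not None
```

So freezing L2's weight does not cut the path from LOSS back to L1.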
I tried a snippet as follows:

import torch
import torch.nn as nn

random_input = torch.randn(3, 3)
random_output = torch.randn(3, 3)
criterion = nn.MSELoss()
l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)
intermediate = l1(random_input)
# intermediate.requires_grad = False  # RuntimeError, see below
# intermediate = intermediate.detach()
output = l2(intermediate)
loss = criterion(output, random_output)
loss.backward()
I think what you confused is the intermediate result with the weights:
requires_grad=False is set on a leaf variable such as
l1.weight, while what the chain rule needs is the gradient with respect to the intermediate result, which is still computed.
And if you try to set requires_grad=False on
intermediate, it will raise an error:
RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().
So if we want to stop gradient tracking at an intermediate variable, we should use
.detach(), and then the related sub-graph will not have gradients flowing back.
Using the snippet above with the intermediate = intermediate.detach() line uncommented, backward will leave l1.weight.grad = None while l2.weight.grad is computed normally.
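That behaviour can be verified with a runnable sketch of the detached version (same shapes as the snippet above):

```python
import torch
import torch.nn as nn

l1 = nn.Linear(3, 10, bias=False)
l2 = nn.Linear(10, 3, bias=False)

intermediate = l1(torch.randn(3, 3))
# detach() returns a new tensor cut out of the graph: gradient flow stops here.
detached = intermediate.detach()
loss = nn.MSELoss()(l2(detached), torch.randn(3, 3))
loss.backward()

# Everything upstream of the detach point receives no gradient.
l1_grad = l1.weight.grad  # stays None
l2_grad = l2.weight.grad  # populated by backward
```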
Yes, that was my issue - I did not mentally separate the weights from the operations on them.
Thank you for replying.