We usually freeze part of a network for the backward pass by setting requires_grad = False, but if we set requires_grad = False, there is no gradient in this frozen layer, so it can't backpropagate via the chain rule, right? Here is an example:
This is a 2-layer network; the forward pass looks like this:
layer1 → layer2 → loss
If we set the params in layer2 with requires_grad=False, the gradient in layer2 will not be computed. But computing the gradient in layer1 uses the gradient from layer2 via the chain rule, so since we set requires_grad = False in layer2, we can't derive the grad in layer1, right? Please point out if I am wrong.
No, since you are only disabling the gradient computation for the explicit parameter (e.g. the wgrad for the weight), while PyTorch will still compute the input gradient (dgrad) if a previous layer needs it.
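For illustration, here is a minimal sketch of the 2-layer setup from the question (the Linear layers and names are just placeholders): freezing layer2's parameters skips only its weight/bias gradients (wgrad), not the gradient w.r.t. its input (dgrad), so layer1 still receives a gradient.

import torch
import torch.nn as nn

# Hypothetical setup matching the question: layer1 -> layer2 -> loss
layer1 = nn.Linear(4, 4)
layer2 = nn.Linear(4, 1)

# Freeze layer2's parameters (their wgrad won't be computed)
for p in layer2.parameters():
    p.requires_grad = False

x = torch.randn(2, 4)
loss = layer2(layer1(x)).mean()
loss.backward()

print(layer1.weight.grad is not None)  # True: dgrad still flowed back through layer2
print(layer2.weight.grad)              # None: wgrad for layer2 was never computed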
Thanks for your quick reply!
So, if I set requires_grad = False in layer2, it still computes the gradient in layer2 since layer1 needs the previous layer's gradient. After calling loss.backward(), only params with requires_grad = True will store the gradient in weight.grad, and then optimizer.step() will update the params which have a grad. requires_grad = False will reduce memory and stop updating the params. Is my understanding right?
No, that’s not correct since unneeded gradients also won’t be computed at all as seen here:
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3)
).cuda()

# uncomment to freeze the intermediate conv layer
#model[1].weight.requires_grad = False
#model[1].bias.requires_grad = False

x = torch.randn(1, 3, 24, 24).cuda()
out = model(x)
out.mean().backward()
In the default use case, when all parameters are trainable, you will see 2 dgrad and 3 wgrad kernel calls in a profiler:
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
21.2 16544 2 8272.0 8272.0 8256 8288 22.6 void cudnn::cnn::dgrad2d_grouped_direct_kernel<float, int, float, float, (bool)0, (bool)1, (int)0, …
20.9 16288 3 5429.3 4896.0 4768 6624 1036.6 void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
19.8 15456 3 5152.0 4928.0 4768 5760 532.6 void implicit_convolve_sgemm<float, float, (int)128, (int)5, (int)5, (int)3, (int)3, (int)3, (int)1…
while freezing the intermediate layer shows that one wgrad kernel is missing, as it's not needed anymore:
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
22.2 15968 2 7984.0 7984.0 7968 8000 22.6 void cudnn::cnn::dgrad2d_grouped_direct_kernel<float, int, float, float, (bool)0, (bool)1, (int)0, …
21.7 15552 3 5184.0 4896.0 4768 5888 613.0 void implicit_convolve_sgemm<float, float, (int)128, (int)5, (int)5, (int)3, (int)3, (int)3, (int)1…
15.2 10944 2 5472.0 5472.0 4736 6208 1040.9 void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
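The same behavior can also be checked without an external profiler via torch.profiler; here is a minimal sketch assuming a CUDA device is available (the exact kernel names depend on the cuDNN version, but one wgrad kernel call should disappear once the intermediate layer is frozen):

import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

model = nn.Sequential(
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3)
).cuda()

# freeze the intermediate layer as in the example above
model[1].weight.requires_grad = False
model[1].bias.requires_grad = False

x = torch.randn(1, 3, 24, 24).cuda()

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    out = model(x)
    out.mean().backward()

# inspect the recorded CUDA kernels; only 2 wgrad calls should show up here
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))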