How does requires_grad work?

We usually freeze part of a network during the backward pass by setting requires_grad = False. But if we set requires_grad = False, there is no gradient in the frozen layer, so backpropagation via the chain rule can't continue through it, right? Here is an example:
This is a two-layer network, with a forward pass like this:
layer1 → layer2 → loss
If we set the params in layer2 to requires_grad=False, no gradient will be computed in layer2. But computing the gradient in layer1 uses the gradient from layer2 via the chain rule, so since we set requires_grad = False in layer2, we can't derive the gradient in layer1, right? Please point out if I am wrong.

No. Setting requires_grad = False only disables the gradient computation for that parameter itself (e.g. the weight gradient, "wgrad"). PyTorch will still compute the gradient with respect to the layer's input (the data gradient, "dgrad") if a previous layer needs it.
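A minimal CPU-only sketch of this, assuming two small bias-free nn.Linear layers stand in for the unspecified layer1 and layer2 (the sizes are arbitrary). It freezes layer2, checks that layer2 gets no .grad while layer1 still does, and then recomputes layer1's gradient by hand: the chain rule through layer2 only needs layer2's weight W2 (dgrad), never dL/dW2 (wgrad).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer toy model: layer1 -> layer2 -> loss
layer1 = nn.Linear(4, 3, bias=False)
layer2 = nn.Linear(3, 2, bias=False)

# Freeze layer2: its weight gradient (wgrad) is not needed anymore
layer2.weight.requires_grad_(False)

x = torch.randn(5, 4)
h = layer1(x)        # h = x @ W1.T
out = layer2(h)      # out = h @ W2.T
loss = out.sum()
loss.backward()

print(layer2.weight.grad)        # None: frozen, no wgrad stored
print(layer1.weight.grad.shape)  # torch.Size([3, 4]): layer1 still gets a gradient

# Manual chain rule for loss = out.sum():
grad_out = torch.ones(5, 2)            # dL/dout
grad_h = grad_out @ layer2.weight      # dgrad through the frozen layer2 (uses W2 only)
grad_w1 = grad_h.T @ x                 # wgrad of layer1: dL/dW1
print(torch.allclose(layer1.weight.grad, grad_w1))  # True
```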

Thanks for your quick reply!
So if I set requires_grad = False in layer2, it still computes the gradient in layer2, since layer1 needs the gradient coming back from layer2. After calling loss.backward(), only params with requires_grad = True will store their gradient in weight.grad, and then optimizer.step() will update the params that have a grad. So requires_grad = False only reduces memory and stops the params from being updated. Is my understanding right?

No, that's not correct: the unneeded gradients (here the weight gradient of the frozen layer) won't be computed at all, as seen here:

import torch
import torch.nn as nn


model = nn.Sequential(
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3)
).cuda()

# Uncomment both lines to freeze the intermediate layer
#model[1].weight.requires_grad = False
#model[1].bias.requires_grad = False

x = torch.randn(1, 3, 24, 24).cuda()
out = model(x)
out.mean().backward()

In the default use case, when all parameters are trainable, a profiler shows 2 dgrad and 3 wgrad kernel calls (the first layer needs no dgrad, since the input x does not require gradients):

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     21.2            16544          2    8272.0    8272.0      8256      8288         22.6  void cudnn::cnn::dgrad2d_grouped_direct_kernel<float, int, float, float, (bool)0, (bool)1, (int)0, …
     20.9            16288          3    5429.3    4896.0      4768      6624       1036.6  void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
     19.8            15456          3    5152.0    4928.0      4768      5760        532.6  void implicit_convolve_sgemm<float, float, (int)128, (int)5, (int)5, (int)3, (int)3, (int)3, (int)1…

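while freezing the intermediate layer shows that one wgrad kernel is missing (2 instances instead of 3), as it's not needed anymore. Both dgrad kernels are still launched, since the gradient has to flow back through the frozen layer to reach the first layer: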

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     22.2            15968          2    7984.0    7984.0      7968      8000         22.6  void cudnn::cnn::dgrad2d_grouped_direct_kernel<float, int, float, float, (bool)0, (bool)1, (int)0, …
     21.7            15552          3    5184.0    4896.0      4768      5888        613.0  void implicit_convolve_sgemm<float, float, (int)128, (int)5, (int)5, (int)3, (int)3, (int)3, (int)1…
     15.2            10944          2    5472.0    5472.0      4736      6208       1040.9  void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
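The other half of the follow-up question (what .grad and optimizer.step() do for frozen parameters) can be checked with a short CPU-only sketch, again using hypothetical nn.Linear layers instead of the convolutions above. After one training step, the frozen layer's parameters still have .grad = None and are unchanged, while the layer before it is updated. Passing only the trainable parameters to the optimizer is a common pattern rather than something this thread prescribes; built-in optimizers such as SGD also skip parameters whose .grad is None.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))

# Freeze the second layer (weight and bias)
for p in model[1].parameters():
    p.requires_grad_(False)

# Hand only the trainable parameters to the optimizer
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.1)

w0_before = model[0].weight.detach().clone()
w1_before = model[1].weight.detach().clone()

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()

print(model[1].weight.grad)                     # None: no wgrad for the frozen layer
print(torch.equal(model[1].weight, w1_before))  # True: frozen layer unchanged
print(torch.equal(model[0].weight, w0_before))  # False: first layer was updated
```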