Gradient calculation only up to a certain layer

Hello,

Let’s say I have a 6-layer Transformer-based model. I would like to calculate gradients only down to the 3rd layer, i.e. the backward pass goes from layer 6 → 5 → 4 → 3. I don’t want to compute gradients for the 1st and 2nd layers; the backward pass stops at the 3rd layer. The number of layers that participate in the backward pass varies with each minibatch, while the forward pass always runs through all the layers. Could anyone help me with this?

Correct me if I am wrong: if I set params.requires_grad = False, does it still compute the gradients but just not update those parameters? Is that the case?

Please help me with how to perform the above operation.

Thank you in advance.

No, if you freeze parameters, and no other parameters or inputs used earlier in the model require gradients, Autograd won’t compute these gradients at all.

Here is a simple example:

import torch
import torch.nn as nn

# assuming a CUDA device so the cuDNN kernels shown below are launched
device = "cuda"

model = nn.Sequential(
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
        nn.Conv2d(3, 3, 3),
).to(device)
x = torch.randn(16, 3, 24, 24, device=device)

# Uncomment to freeze the first 9 layers (i.e. all but the last one):
# for i in range(9):
#     model[i].weight.requires_grad = False
#     model[i].bias.requires_grad = False

out = model(x)
out.mean().backward()

Profiling this code shows 10 wgrad and 9 dgrad kernels:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     37.2          111,233         10  11,123.3   9,792.0     2,623    23,520      7,101.1  void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
     20.3           60,799         10   6,079.9   6,064.0     5,632     6,784        339.9  void cudnn::cnn::conv2d_grouped_direct_kernel<(bool)0, (bool)1, (bool)0, (bool)0, (bool)0, (bool)0,…
     17.0           50,880          9   5,653.3   5,632.0     5,536     5,760         73.3  void cudnn::cnn::dgrad2d_grouped_direct_kernel<float, int, float, float, (bool)0, (bool)1, (int)0, …
     15.4           46,113         10   4,611.3   4,272.0     3,488     6,592      1,073.4  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::func_wrapp…
      7.1           21,185         10   2,118.5   2,016.0     1,952     3,136        359.1  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      1.1            3,232          1   3,232.0   3,232.0     3,232     3,232          0.0  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
      0.9            2,817          1   2,817.0   2,817.0     2,817     2,817          0.0  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.6            1,920          1   1,920.0   1,920.0     1,920     1,920          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      0.4            1,151          1   1,151.0   1,151.0     1,151     1,151          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, std::array<c…
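
(Side note: the kernel tables in this post look like summaries from an external GPU profiler such as Nsight Systems. If it helps, a rough equivalent sketch using torch.profiler can also list the backward kernels, assuming the same model and input as above.)

from torch.profiler import profile, ProfilerActivity

# Profile one forward + backward pass and print the CUDA kernel summary;
# kernel names containing "wgrad" / "dgrad" are the weight- and data-gradient
# convolution kernels discussed here.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out = model(x)
    out.mean().backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))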

Now if we freeze all but the last layer, we will see a single wgrad kernel:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     63.1           61,216         10   6,121.6   6,064.0     5,696     6,752        294.5  void cudnn::cnn::conv2d_grouped_direct_kernel<(bool)0, (bool)1, (bool)0, (bool)0, (bool)0, (bool)0,…
     21.5           20,895         10   2,089.5   2,016.0     1,952     2,816        257.8  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      3.4            3,264          1   3,264.0   3,264.0     3,264     3,264          0.0  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::func_wrapp…
      3.3            3,200          1   3,200.0   3,200.0     3,200     3,200          0.0  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
      2.8            2,752          1   2,752.0   2,752.0     2,752     2,752          0.0  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      2.7            2,656          1   2,656.0   2,656.0     2,656     2,656          0.0  void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
      1.9            1,824          1   1,824.0   1,824.0     1,824     1,824          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      1.2            1,184          1   1,184.0   1,184.0     1,184     1,184          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, std::array<c…

If we unfreeze the last two layers, we will see 2 wgrad kernels and 1 dgrad kernel:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     55.0           60,706         10   6,070.6   6,032.5     5,600     6,721        311.2  void cudnn::cnn::conv2d_grouped_direct_kernel<(bool)0, (bool)1, (bool)0, (bool)0, (bool)0, (bool)0,…
     19.2           21,184         10   2,118.4   2,000.0     1,920     3,103        350.8  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      6.8            7,456          2   3,728.0   3,728.0     3,424     4,032        429.9  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::func_wrapp…
      5.9            6,496          2   3,248.0   3,248.0     2,688     3,808        792.0  void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
      5.1            5,600          1   5,600.0   5,600.0     5,600     5,600          0.0  void cudnn::cnn::dgrad2d_grouped_direct_kernel<float, int, float, float, (bool)0, (bool)1, (int)0, …
      2.8            3,137          1   3,137.0   3,137.0     3,137     3,137          0.0  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
      2.5            2,784          1   2,784.0   2,784.0     2,784     2,784          0.0  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      1.7            1,856          1   1,856.0   1,856.0     1,856     1,856          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      1.1            1,184          1   1,184.0   1,184.0     1,184     1,184          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, std::array<c…
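
For reference, a minimal sketch of the freezing used in this last case, assuming the same 10-layer model as above:

# Freeze layers 0-7; only layers 8 and 9 keep requires_grad=True, so two wgrad
# kernels are launched, plus one dgrad kernel for layer 9's input gradient,
# which is needed to compute layer 8's weight gradient.
for i in range(8):
    model[i].weight.requires_grad = False
    model[i].bias.requires_grad = False

out = model(x)
out.mean().backward()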

Let’s say I do something like this.

Let’s say I get the features at each intermediate layer as well as the logits:

(1) Features, logits = Model(inputs)

(2) List1 = Get_layers(Features)
# list1 tells me which leading layers (counting from layer 1) do not need to participate in the backward pass.

(3) Loss = lossfn(logits, labels)

(4) Loss.backward()

Let’s say that at (1) all the layers have requires_grad = True. Now list1 gives me [1, 2] (which means layers 1 and 2 do not have to calculate gradients at all). After (3), I set layer[1, 2].param.requires_grad = False.

After step (4), does it still calculate gradients for layers 1 and 2 and just not update them? Or does it not calculate those gradients at all?
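
For context, here is a rough sketch of the flow I have in mind (the loader, model, get_layers, and lossfn names are just placeholders for my setup):

# Hypothetical per-minibatch flow following steps (1)-(4) above.
for inputs, labels in loader:
    # (1) forward pass through all layers
    features, logits = model(inputs)

    # (2) decide which leading layers should not participate in the backward pass
    skip_layers = get_layers(features)   # e.g. [1, 2]

    # (3) compute the loss
    loss = lossfn(logits, labels)

    # freeze the selected layers only now, i.e. after the forward pass
    for idx in skip_layers:
        for p in model.layers[idx].parameters():
            p.requires_grad = False

    # (4) backward pass -- my question is whether this still computes
    # gradients for the layers frozen above
    loss.backward()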

If possible, could you please explain what wgrad and dgrad are?