No. If you freeze parameters, and no other parameters or inputs used earlier in the model require gradients, Autograd won't compute these gradients at all.
Here is a simple example:
import torch
import torch.nn as nn

# assumes a CUDA-capable GPU is available
device = "cuda"

model = nn.Sequential(
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
    nn.Conv2d(3, 3, 3),
).to(device)
x = torch.randn(16, 3, 24, 24, device=device)

# Uncomment to freeze the first 9 layers (all but the last):
# for i in range(9):
#     model[i].weight.requires_grad = False
#     model[i].bias.requires_grad = False

out = model(x)
out.mean().backward()
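The profiler output below looks like Nsight Systems (nsys) statistics. As an alternative that stays in Python (an assumption on my side, not what produced the tables below), a minimal torch.profiler sketch surfacing the same CUDA kernels would be:

from torch.profiler import profile, ProfilerActivity

# Record the CUDA kernels launched by the forward and backward pass.
with profile(activities=[ProfilerActivity.CUDA]) as prof:
    out = model(x)
    out.mean().backward()

# Aggregate per-kernel statistics, sorted by total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total"))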
Profiling this code shows 10 wgrad and 9 dgrad kernels (no dgrad kernel is launched for the first layer, since its input x does not require gradients):
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
37.2 111,233 10 11,123.3 9,792.0 2,623 23,520 7,101.1 void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
20.3 60,799 10 6,079.9 6,064.0 5,632 6,784 339.9 void cudnn::cnn::conv2d_grouped_direct_kernel<(bool)0, (bool)1, (bool)0, (bool)0, (bool)0, (bool)0,…
17.0 50,880 9 5,653.3 5,632.0 5,536 5,760 73.3 void cudnn::cnn::dgrad2d_grouped_direct_kernel<float, int, float, float, (bool)0, (bool)1, (int)0, …
15.4 46,113 10 4,611.3 4,272.0 3,488 6,592 1,073.4 void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::func_wrapp…
7.1 21,185 10 2,118.5 2,016.0 1,952 3,136 359.1 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
1.1 3,232 1 3,232.0 3,232.0 3,232 3,232 0.0 void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
0.9 2,817 1 2,817.0 2,817.0 2,817 2,817 0.0 void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
0.6 1,920 1 1,920.0 1,920.0 1,920 1,920 0.0 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
0.4 1,151 1 1,151.0 1,151.0 1,151 1,151 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, std::array<c…
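To freeze all but the last layer, we can enable the commented-out loop from the snippet above before running the forward and backward pass:

for i in range(9):  # freeze layers 0-8; layer 9 (the last conv) stays trainable
    model[i].weight.requires_grad = False
    model[i].bias.requires_grad = False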
Now, if we freeze all but the last layer, we will see a single wgrad kernel (and no dgrad kernels, since no layer before the last one needs a gradient):
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
63.1 61,216 10 6,121.6 6,064.0 5,696 6,752 294.5 void cudnn::cnn::conv2d_grouped_direct_kernel<(bool)0, (bool)1, (bool)0, (bool)0, (bool)0, (bool)0,…
21.5 20,895 10 2,089.5 2,016.0 1,952 2,816 257.8 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
3.4 3,264 1 3,264.0 3,264.0 3,264 3,264 0.0 void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::func_wrapp…
3.3 3,200 1 3,200.0 3,200.0 3,200 3,200 0.0 void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
2.8 2,752 1 2,752.0 2,752.0 2,752 2,752 0.0 void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
2.7 2,656 1 2,656.0 2,656.0 2,656 2,656 0.0 void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
1.9 1,824 1 1,824.0 1,824.0 1,824 1,824 0.0 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
1.2 1,184 1 1,184.0 1,184.0 1,184 1,184 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, std::array<c…
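To unfreeze the last two layers, we only freeze the first 8 layers instead; a sketch of the changed loop:

for i in range(8):  # freeze layers 0-7; layers 8 and 9 stay trainable
    model[i].weight.requires_grad = False
    model[i].bias.requires_grad = False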
If we unfreeze the last two layers, we will see two wgrad kernels and one dgrad kernel (the dgrad computes the last layer's input gradient, which is needed as the incoming gradient for the second-to-last layer's wgrad):
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
-------- --------------- --------- -------- -------- -------- -------- ----------- ----------------------------------------------------------------------------------------------------
55.0 60,706 10 6,070.6 6,032.5 5,600 6,721 311.2 void cudnn::cnn::conv2d_grouped_direct_kernel<(bool)0, (bool)1, (bool)0, (bool)0, (bool)0, (bool)0,…
19.2 21,184 10 2,118.4 2,000.0 1,920 3,103 350.8 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
6.8 7,456 2 3,728.0 3,728.0 3,424 4,032 429.9 void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::func_wrapp…
5.9 6,496 2 3,248.0 3,248.0 2,688 3,808 792.0 void cudnn::cnn::wgrad2d_grouped_direct_kernel<(bool)0, (bool)1, int, float, float, float>(cudnn::c…
5.1 5,600 1 5,600.0 5,600.0 5,600 5,600 0.0 void cudnn::cnn::dgrad2d_grouped_direct_kernel<float, int, float, float, (bool)0, (bool)1, (int)0, …
2.8 3,137 1 3,137.0 3,137.0 3,137 3,137 0.0 void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
2.5 2,784 1 2,784.0 2,784.0 2,784 2,784 0.0 void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
1.7 1,856 1 1,856.0 1,856.0 1,856 1,856 0.0 void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
1.1 1,184 1 1,184.0 1,184.0 1,184 1,184 0.0 void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, std::array<c…
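As a quick sanity check (a minimal sketch, assuming the first 8 layers were frozen before a single backward call), we can confirm that frozen parameters never receive a .grad attribute:

out = model(x)
out.mean().backward()

assert model[0].weight.grad is None       # frozen: no gradient was computed
assert model[9].weight.grad is not None   # trainable: gradient is populated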