torch.backends.cudnn.benchmark=True picks a slower algorithm

I am getting worse profiling results with cudnn.benchmark = True on my toy example (see below), and I am wondering if this is user error or a bug / incompatible build?
My understanding is that, especially in combination with cudnn.benchmark_limit = 0, cuDNN should brute-force through absolutely all available choices and pick the fastest one?
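
That is, my reading of the docs for the two settings I am testing:

from torch.backends import cudnn

cudnn.benchmark = True      # time candidate conv algorithms for each input shape, cache the winner
cudnn.benchmark_limit = 0   # 0 = no cap, i.e. try every available algorithm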

(I also looked through other slowdown-related topics, but none matched this case.)

System information:

  • GPU: 1x NVIDIA L4
  • Driver Version: 560.35.03
  • Ubuntu 24.04 (Docker)
  • Python 3.11.11
  • PyTorch 2.5.1 (built from source):
    • CUDA 12.6
    • cuDNN 9.6.0.74-1

PyTorch details:

print(torch.__config__.show())
PyTorch built with:
  - GCC 13.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2025.0.1-Product Build 20241031 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.5.3 (Git Hash 66f0cb9eb66affd2da3bf5f8d897376f04aae6af)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.6
  - NVCC architecture flags: -gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_89,code=sm_89
  - CuDNN 90.6
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.6, CUDNN_VERSION=9.6.0, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS=-O2 -pipe -D_GLIBCXX_USE_CXX11_ABI=1 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.5.0, USE_CUDA=1, USE_CUDNN=ON, USE_CUSPARSELT=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

cudnn_bench.py

import torch
from torch.backends import cudnn

kernel = torch.rand(64, 3, 3, 3, device='cuda')

cudnn.benchmark = False  # set to True for the second run below
cudnn.benchmark_limit = 0  # 0 = try every available algorithm
torch.cuda.synchronize()

# Warmup: the first iteration triggers cuDNN's algorithm selection
for _ in range(50):
    data = torch.rand(64, 3, 224, 224, device='cuda')
    _ = torch.nn.functional.conv2d(data, kernel)
torch.cuda.synchronize()

# Profile 100 iterations with the selected algorithm
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
) as prof:
    for _ in range(100):
        data = torch.rand(64, 3, 224, 224, device='cuda')
        _ = torch.nn.functional.conv2d(data, kernel)

print(prof.key_averages().table(sort_by='self_cuda_time_total', row_limit=10))

Output (cudnn.benchmark=False):

python cudnn_bench.py
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::cudnn_convolution         0.58%       2.199ms         0.92%       3.488ms      34.876us     362.741ms        95.90%     362.741ms       3.627ms           100  
_5x_cudnn_ampere_scudnn_128x64_relu_xregs_large_nn_v...         0.00%       0.000us         0.00%       0.000us       0.000us     362.427ms        95.82%     362.427ms       3.624ms           100  
                                         aten::uniform_         0.28%       1.048ms         0.50%       1.914ms      19.137us      15.509ms         4.10%      15.509ms     155.086us           100  
void at::native::(anonymous namespace)::distribution...         0.00%       0.000us         0.00%       0.000us       0.000us      15.509ms         4.10%      15.509ms     155.086us           100  
void cask__5x_cudnn::computeOffsetsKernel<false, fal...         0.00%       0.000us         0.00%       0.000us       0.000us     313.564us         0.08%     313.564us       3.136us           100  
                                             aten::rand         0.10%     382.634us         1.42%       5.401ms      54.012us       0.000us         0.00%      15.509ms     155.086us           100  
                                            aten::empty         0.82%       3.105ms         0.82%       3.105ms      31.049us       0.000us         0.00%       0.000us       0.000us           100  
                                  cudaStreamIsCapturing         0.03%     119.870us         0.03%     119.870us       0.599us       0.000us         0.00%       0.000us       0.000us           200  
                                       cudaLaunchKernel         0.52%       1.984ms         0.52%       1.984ms       6.615us       0.000us         0.00%       0.000us       0.000us           300  
                                           aten::conv2d         0.04%     134.365us         1.11%       4.220ms      42.204us       0.000us         0.00%     362.741ms       3.627ms           100  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 379.873ms
Self CUDA time total: 378.249ms

Output (cudnn.benchmark=True and cudnn.benchmark_limit=0):

python cudnn_bench.py
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::cudnn_convolution         0.40%       2.042ms         0.65%       3.291ms      32.912us     492.278ms        97.28%     492.278ms       4.923ms           100  
       _5x_cudnn_ampere_scudnn_128x128_relu_small_nn_v1         0.00%       0.000us         0.00%       0.000us       0.000us     491.961ms        97.22%     491.961ms       4.920ms           100  
                                         aten::uniform_         0.20%       1.018ms         0.37%       1.885ms      18.845us      13.759ms         2.72%      13.759ms     137.587us           100  
void at::native::(anonymous namespace)::distribution...         0.00%       0.000us         0.00%       0.000us       0.000us      13.759ms         2.72%      13.759ms     137.587us           100  
void cask__5x_cudnn::computeOffsetsKernel<false, fal...         0.00%       0.000us         0.00%       0.000us       0.000us     317.248us         0.06%     317.248us       3.172us           100  
                                             aten::rand         0.06%     295.520us         1.04%       5.258ms      52.584us       0.000us         0.00%      13.759ms     137.587us           100  
                                            aten::empty         0.61%       3.078ms         0.61%       3.078ms      30.784us       0.000us         0.00%       0.000us       0.000us           100  
                                  cudaStreamIsCapturing         0.02%     118.718us         0.02%     118.718us       0.594us       0.000us         0.00%       0.000us       0.000us           200  
                                       cudaLaunchKernel         0.38%       1.945ms         0.38%       1.945ms       6.484us       0.000us         0.00%       0.000us       0.000us           300  
                                           aten::conv2d         0.02%     123.625us         0.79%       3.994ms      39.936us       0.000us         0.00%     492.278ms       4.923ms           100  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 507.724ms
Self CUDA time total: 506.037ms

Additional info:
I noticed that I get the fast algorithm in both cases if I add one more call of this convolution at the beginning of the script, before setting cudnn.benchmark = True.
Does the benchmark itself perhaps also require some sort of warmup? I.e. did it ignore the first (fastest) choice because some benchmark-related setup “polluted” that measurement?
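
For reference, this is the workaround I mean (a sketch; only the extra call before toggling the flag is new, the rest is the script from above):

import torch
from torch.backends import cudnn

kernel = torch.rand(64, 3, 3, 3, device='cuda')

# One extra convolution call *before* enabling benchmarking; with this in
# place, both settings ended up picking the fast algorithm in my runs.
data = torch.rand(64, 3, 224, 224, device='cuda')
_ = torch.nn.functional.conv2d(data, kernel)
torch.cuda.synchronize()

cudnn.benchmark = True
cudnn.benchmark_limit = 0
# ... warmup and profiling loops as in the script above ...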

Which GPU are you using?

Oops, thanks for pointing that out. It’s a GCP g2-standard-8 instance with 1x NVIDIA L4 - added it to the System information at the top.

OK, looking at the nsys profile results for the benchmark section, I can see what happened - but not why.

Each algorithm only gets measured twice, and based on the first two measurements of each algorithm, relu_small_nn_v1 was indeed better:

 #   Name                                                      Start      Duration
 1   _5x_cudnn_ampere_scudnn_128x64_relu_xregs_large_nn_v1     2.35358s   3.611 ms
 2   _5x_cudnn_ampere_scudnn_128x64_relu_xregs_large_nn_v1     2.35735s   3.609 ms
 1   _5x_cudnn_ampere_scudnn_128x128_relu_small_nn_v1          2.36117s   3.527 ms
 2   _5x_cudnn_ampere_scudnn_128x128_relu_small_nn_v1          2.36479s   3.494 ms
49   _5x_cudnn_ampere_scudnn_128x128_relu_small_nn_v1          3.49518s   5.771 ms
50   _5x_cudnn_ampere_scudnn_128x128_relu_small_nn_v1          3.5011s    5.770 ms

So now I have even more questions:

  • Why does the time vary so much for the first ~30 iterations? It happens every time I run the script, so it is not an occasional, temporary slowdown.
  • Is there a way to increase the number of tries cudnn.benchmark will perform?

Update:
Yes, there is a way to change the cuDNN benchmark behavior, found in PyTorch's cuDNN v8 API convolution code:

Currently, the sampling technique is hardcoded to CUDNN_FIND_SAMPLE_ONCE, but there are also CUDNN_FIND_SAMPLE_MEDIAN_OF_THREE and CUDNN_FIND_SAMPLE_TILL_STABLE.
Also, I take the existence of the latter as confirmation that the wobbly performance in the first 30 iterations is not unexpected, either…

I’m not sure the root cause of the observed slowdown is actually benchmarking inaccuracy. It is a somewhat unusual scenario, but it seems that the current script, in conjunction with the L4 (which has only a 72W power limit?), might be throttling back and slowing down more in the benchmarking case.

Namely, I observe that reducing the number of “warmup” iterations in the script, along with the number of timing iterations, brings the benchmark=True and benchmark=False cases much closer together. I suspect this is because benchmark=True is more liable to saturate the power and/or thermal limits of the GPU (due to the additional kernels being run during the search), such that the actual profiled iterations are slowed down. Additionally, note that cuDNN will only benchmark/autotune in the very first warmup iteration; the following warmup iterations will not affect the kernel selection itself.
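
One way to sanity-check the throttling hypothesis is to poll power, temperature, and SM clock between iterations. A minimal sketch (not part of the original script; uses the nvidia-ml-py / pynvml bindings):

import pynvml  # pip install nvidia-ml-py
import torch

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

kernel = torch.rand(64, 3, 3, 3, device='cuda')
for i in range(100):
    data = torch.rand(64, 3, 224, 224, device='cuda')
    _ = torch.nn.functional.conv2d(data, kernel)
    torch.cuda.synchronize()
    power_w = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0  # NVML reports milliwatts
    temp_c = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
    sm_mhz = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)
    print(f'iter {i:3d}: {power_w:5.1f} W  {temp_c} C  SM {sm_mhz} MHz')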

Finally, for future experimentation, I opened a PR to make the cuDNN benchmarking technique user-controllable: [CUDNN][CUDNN V8 API] Allow user-specified CUDNN V8 API benchmarking technique by eqy · Pull Request #145779 · pytorch/pytorch


It looks like that is exactly what is happening. I added a 100 ms sleep after each iteration to prevent any power/temperature-related throttling, and now the measurements for each iteration are consistent (and cudnn.benchmark=True picks the best choice, assuming no external throttling):

cudnn.benchmark=  Instances  Avg       Med       Min       Max       StdDev     Name
true              54         3.523 ms  3.511 ms  3.482 ms  4.101 ms  80.590 μs  _5x_cudnn_ampere_scudnn_128x128_relu_small_nn_v1
false             51         3.638 ms  3.638 ms  3.606 ms  3.646 ms  6.006 μs   _5x_cudnn_ampere_scudnn_128x64_relu_xregs_large_nn_v1
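
For reference, the only change was a synchronized sleep at the end of each loop iteration (a sketch of the modified loop):

import time

for _ in range(100):
    data = torch.rand(64, 3, 224, 224, device='cuda')
    _ = torch.nn.functional.conv2d(data, kernel)
    torch.cuda.synchronize()  # wait for the conv kernel to finish...
    time.sleep(0.1)           # ...then give the GPU 100 ms to recover power/thermal headroom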

Next, I want to check if I can also get benchmark to pick the best choice for the throttled case. Will update later.

Thanks so much, @eqy!

Update
Getting benchmark to pick a more “heat-resistant” algorithm works, but only if I run my ~200 warmup iterations (before enabling cudnn.benchmark) with a slightly differently shaped input tensor; otherwise nothing happens, i.e. further iterations with that convolution won’t trigger sampling of new algorithms, and it just keeps using the default one… maybe a bug. A sketch of what worked for me is below.
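
Sketch (the 223×223 warmup shape is an arbitrary “slightly different” choice):

kernel = torch.rand(64, 3, 3, 3, device='cuda')

# ~200 warmup iterations with a *different* input shape, benchmarking still off;
# this heats up the GPU without caching an algorithm for the real shape.
cudnn.benchmark = False
for _ in range(200):
    data = torch.rand(64, 3, 223, 223, device='cuda')
    _ = torch.nn.functional.conv2d(data, kernel)
torch.cuda.synchronize()

# Only now enable benchmarking; the first conv with the real (64, 3, 224, 224)
# input then triggers algorithm sampling on an already-warm (throttled) GPU.
cudnn.benchmark = True
cudnn.benchmark_limit = 0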

Manually throttling the GPU should give more consistent results?

Maybe there is still something going on other than power/temperature-related throttling? I manually reduced the power limit of that L4 from 72W to 40W and still see the same pattern. nvidia-smi reports idle temperatures of 76°C; I saw a short spike to 80°C at 72W, and 77°C at 40W.

Repeated the run a few hours later, when the idle temperature was down to ~55°C, and the measurements jumped from 4 ms to 22 ms… I don’t know what it is, but I am now fairly sure it’s not PyTorch related.