Conv1d: Requirements for nondeterministic algorithm usage

I noticed large differences in the runtime of Conv1d for different values of out_channels. Below are three tests I ran with different out_channels sizes.

import torch
import timeit

class CNN(torch.nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size):
        super().__init__()
        self.conv_1 = torch.nn.Conv1d(in_dim, out_dim, kernel_size)

    def forward(self, x):
        return self.conv_1(x)

def test_run(x, cnn):
    output = cnn(x).sum()
    output.backward()

device = "cuda:1"
x = torch.randn(1000, 77, 400).to(device)

cnn_120 = CNN(77, 120, 2).to(device)
cnn_100 = CNN(77, 100, 2).to(device)
cnn_50 = CNN(77, 50, 2).to(device)

timeit.timeit("test_run(x, cnn_120)", globals=globals(), number=500)
# Output: 18.364156678318977
timeit.timeit("test_run(x, cnn_100)", globals=globals(), number=500)
# Output: 3.4758292948827147
timeit.timeit("test_run(x, cnn_50)", globals=globals(), number=500)
# Output: 23.5902945268899

My assumption is that cnn_100 is so much faster because it uses a nondeterministic algorithm. The documentation of Conv1d says: "In some circumstances when using the CUDA backend with CuDNN, this operator may select a nondeterministic algorithm to increase performance."

My question is: what requirements does my network have to fulfill for the nondeterministic algorithm to be used?
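As far as I know there is no fixed rule on the network side: cuDNN picks an algorithm per input shape and convolution configuration, and the PyTorch flags only constrain or tune that choice. `torch.backends.cudnn.deterministic = True` restricts cuDNN to deterministic algorithms, and `torch.backends.cudnn.benchmark = True` lets cuDNN time several candidates and keep the fastest. A rough empirical check is to run the same backward pass twice and compare the weight gradients bitwise. The helper `weight_grads_reproducible` below is hypothetical, written only for illustration; a mismatch shows that a nondeterministic algorithm was used, while a match is only weak evidence that it was not.

```python
import torch


def weight_grads_reproducible(conv, x, deterministic):
    """Run forward+backward twice and report whether the weight grads are bitwise equal.

    Temporarily sets the global torch.backends.cudnn.deterministic flag and
    restores the previous value afterwards.
    """
    previous = torch.backends.cudnn.deterministic
    torch.backends.cudnn.deterministic = deterministic
    try:
        grads = []
        for _ in range(2):
            conv.zero_grad(set_to_none=True)
            conv(x).sum().backward()
            grads.append(conv.weight.grad.clone())
    finally:
        torch.backends.cudnn.deterministic = previous  # restore the global flag
    return torch.equal(grads[0], grads[1])


device = "cuda" if torch.cuda.is_available() else "cpu"
batch = 256 if device == "cuda" else 8  # keep the CPU fallback cheap
x = torch.randn(batch, 77, 126, device=device)
for out_dim in (50, 100, 120):
    conv = torch.nn.Conv1d(77, out_dim, 2).to(device)
    print(out_dim,
          "default:", weight_grads_reproducible(conv, x, deterministic=False),
          "deterministic:", weight_grads_reproducible(conv, x, deterministic=True))
```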

CUDA operations are executed asynchronously, so you would have to synchronize the code before starting and before stopping the timer. Alternatively, you could use the torch.utils.benchmark utility, which synchronizes the code for you and also adds warmup iterations for proper profiling.
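For the manual route, a minimal sketch of synchronized timing could look like this. `timed_fwd_bwd` is a hypothetical helper, not a PyTorch API: it runs a few warmup iterations, then calls `torch.cuda.synchronize()` before starting and before stopping the clock so that all queued kernels are included in the measurement. It falls back to the CPU (with a much smaller batch) when no GPU is available.

```python
import time

import torch


def timed_fwd_bwd(model, x, iters=50, warmup=5):
    """Average wall-clock seconds per forward+backward pass of `model` on `x`."""
    for _ in range(warmup):  # exclude one-time costs (cuDNN setup, allocator warmup)
        model(x).sum().backward()
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        model(x).sum().backward()
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for the last kernels before stopping the clock
    return (time.perf_counter() - start) / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
batch = 1000 if device == "cuda" else 8  # keep the CPU fallback cheap
x = torch.randn(batch, 77, 400, device=device)
for out_dim in (50, 100, 120):
    conv = torch.nn.Conv1d(77, out_dim, 2).to(device)
    print(f"out_channels={out_dim}: {timed_fwd_bwd(conv, x, iters=20) * 1e3:.2f} ms")
```

Without the two synchronize calls, the timer mostly measures how fast the host can enqueue kernels, not how long the GPU takes to run them.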

Thanks for pointing me to torch.utils.benchmark. I built a new benchmark with it and also adjusted the settings so that the convolution should be deterministic.

import torch
import torch.utils.benchmark as benchmark
from itertools import product

class CNN(torch.nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size):
        super().__init__()
        self.conv_1 = torch.nn.Conv1d(in_dim, out_dim, kernel_size)

    def forward(self, x):
        return self.conv_1(x)

def test_cnn(x, cnn):
    # Timed statement: one forward + backward pass.
    return cnn(x).sum().backward()

# Disable cuDNN autotuning and allow only deterministic algorithms.
# torch.set_deterministic(True) is deprecated; torch.use_deterministic_algorithms replaces it.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

device = "cuda:0"
inputs = {
    "x1": torch.randn(256, 77, 126, device=device),
    "x2": torch.randn(211, 77, 137, device=device),
    "x3": torch.randn(512, 77, 221, device=device),
}

out_dims = [50, 80, 100, 120, 150]
kernel_sizes = [2, 3, 4]
results = []

for out_dim, kernel_size in product(out_dims, kernel_sizes):
    sub_label = f"[77, {out_dim}, {kernel_size}]"
    cnn = CNN(77, out_dim, kernel_size).to(device)
    # num_threads sets the CPU thread pool size; it does not change the GPU kernels.
    for num_threads in [1, 4]:
        for description, x in inputs.items():
            results.append(benchmark.Timer(
                stmt="test_cnn(x, cnn)",
                setup="from __main__ import test_cnn",
                globals={"x": x, "cnn": cnn},
                num_threads=num_threads,
                sub_label=sub_label,
                description=description,
            ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()
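`blocked_autorange` returns a `torch.utils.benchmark.Measurement`, so the timings can also be used programmatically instead of only through `Compare`. A small sketch, where the `measure` helper is hypothetical: it reads the median and interquartile range of one configuration and passes `label=` so that a `Compare` table built from such measurements gets a title instead of an empty header.

```python
import torch
import torch.utils.benchmark as benchmark


def measure(out_dim, kernel_size, x, min_run_time=0.2):
    """Time forward+backward of a fresh Conv1d on `x` and return the Measurement."""
    conv = torch.nn.Conv1d(x.shape[1], out_dim, kernel_size).to(x.device)
    return benchmark.Timer(
        stmt="conv(x).sum().backward()",
        globals={"conv": conv, "x": x},
        label="Conv1d fwd+bwd",  # shown as the title row when passed to benchmark.Compare
        sub_label=f"[{x.shape[1]}, {out_dim}, {kernel_size}]",
        description=f"batch={x.shape[0]}",
    ).blocked_autorange(min_run_time=min_run_time)


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 77, 126, device=device)
m = measure(100, 2, x)
print(f"median {m.median * 1e3:.3f} ms, IQR {m.iqr * 1e3:.3f} ms")
```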

Here is the output I get:

[--------------------  -------------------]
                    |   x1  |   x2  |   x3 
1 threads: --------------------------------
      [77, 50, 2]   |  4.7  |  4.2  |  16.7
      [77, 50, 3]   |  4.7  |  4.1  |  16.6
      [77, 50, 4]   |  4.5  |  4.0  |  15.9
      [77, 80, 2]   |  4.4  |  3.9  |  15.6
      [77, 80, 3]   |  4.3  |  3.8  |  15.4
      [77, 80, 4]   |  4.3  |  3.8  |  14.9
      [77, 100, 2]  |  3.5  |  3.2  |  12.1
      [77, 100, 3]  |  3.5  |  3.2  |  12.3
      [77, 100, 4]  |  3.6  |  3.2  |  12.4
      [77, 120, 2]  |  3.6  |  3.2  |  12.3
      [77, 120, 3]  |  3.6  |  3.2  |  12.4
      [77, 120, 4]  |  3.6  |  3.2  |  12.5
      [77, 150, 2]  |  5.1  |  4.5  |  18.1
      [77, 150, 3]  |  5.0  |  4.5  |  18.0
      [77, 150, 4]  |  4.9  |  4.4  |  17.7
4 threads: --------------------------------
      [77, 50, 2]   |  4.7  |  4.2  |  16.7
      [77, 50, 3]   |  4.7  |  4.1  |  16.6
      [77, 50, 4]   |  4.5  |  4.0  |  15.9
      [77, 80, 2]   |  4.4  |  3.9  |  15.6
      [77, 80, 3]   |  4.3  |  3.8  |  15.4
      [77, 80, 4]   |  4.3  |  3.8  |  14.9
      [77, 100, 2]  |  3.5  |  3.2  |  12.1
      [77, 100, 3]  |  3.5  |  3.2  |  12.3
      [77, 100, 4]  |  3.6  |  3.2  |  12.4
      [77, 120, 2]  |  3.6  |  3.2  |  12.3
      [77, 120, 3]  |  3.6  |  3.2  |  12.4
      [77, 120, 4]  |  3.6  |  3.2  |  12.5
      [77, 150, 2]  |  5.1  |  4.5  |  18.1
      [77, 150, 3]  |  5.0  |  4.5  |  18.0
      [77, 150, 4]  |  4.9  |  4.4  |  17.7

Times are in milliseconds (ms).

I find these results quite confusing. I expected the runtime to increase with the size of out_channels. However, the results indicate some kind of sweet spot around 100/120. Also, for out_channels 100 and 120, increasing kernel_size increases the runtime, while for the other out_channels sizes the runtime decreases as kernel_size increases. Is there an explanation for these results?
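For reference, the arithmetic work of a Conv1d forward pass is batch * out_channels * L_out * in_channels * kernel_size multiply-accumulates, so it grows linearly in out_channels and roughly linearly in kernel_size. The hypothetical helper below (ignoring bias and groups) evaluates that count using the output-length formula from the Conv1d documentation. If the same algorithm were used for every configuration, runtimes would roughly follow this count, so a non-monotonic pattern suggests that different kernels are being chosen.

```python
def conv1d_macs(batch, c_in, l_in, c_out, kernel_size, stride=1, padding=0, dilation=1):
    """Multiply-accumulates of one Conv1d forward pass (bias and groups ignored)."""
    # Output length as given in the torch.nn.Conv1d documentation.
    l_out = (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return batch * c_out * l_out * c_in * kernel_size


# Shapes of x1 from the benchmark above: (256, 77, 126).
for c_out in (50, 80, 100, 120, 150):
    for k in (2, 3, 4):
        print(c_out, k, conv1d_macs(256, 77, 126, c_out, k))
```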

cuDNN may use different kernels for different setups (input shape, convolution parameters). In your example, a matrix-multiplication-based kernel should be used, and on my machine these kernels are used:

CUDA Kernel Statistics:

 Time(%)  Total Time (ns)  Instances   Average   Minimum  Maximum                                                  Name
 -------  ---------------  ---------  ---------  -------  -------  ----------------------------------------------------------------------------------------------------
    34.1      43451404533      55060   789164.6   625733  3598010  void cudnn::cnn::wgrad_alg1_engine<float, float, 128, 6, 7, 3, 3, 5, false, true>(int, int, int, fl…
    23.0      29278562879       4976  5883955.6  5541866  6217167  volta_scudnn_128x128_stridedB_splitK_interior_nn_v1
    11.6      14821780890      66001   224569.0    64480   821926  void at::native::unrolled_elementwise_kernel<at::native::AddFunctor<float>, at::detail::Array<char*…
     7.6       9615588408       4510  2132059.5  2007375  2843061  void cudnn::cnn::wgrad_alg1_engine<float, float, 128, 6, 8, 3, 3, 5, false, true>(int, int, int, fl…
     5.0       6321303140      16508   382923.6   101057  1380074  volta_scudnn_128x32_relu_interior_nn_v1
     3.7       4717783259       1455  3242462.7  3099383  3428283  volta_scudnn_128x64_stridedB_splitK_interior_nn_v1
     3.3       4150389976      20185   205617.5    61825   589796  volta_scudnn_128x64_relu_medium_nn_v1
     3.3       4143596964      66001    62780.8    17121   280162  void at::native::unrolled_elementwise_kernel<at::native::copy_device_to_device(at::TensorIterator&,…
     2.7       3455464856     132002    26177.4     9376   172546  void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::native::func_wrapper_t<float…
     2.4       3045423927      17554   173488.9   101440   286466  volta_scudnn_128x32_relu_medium_nn_v1
     1.9       2402193733       8310   289072.7   197633   759462  volta_scudnn_128x128_relu_small_nn_v1
     1.0       1244816916       3444   361445.1   312547   421859  void implicit_convolve_sgemm<float, float, 128, 5, 5, 3, 3, 3, 1, false, false, true>(int, int, int…
     0.2        309497449     131972     2345.2     1439   301442  void at::native::vectorized_elementwise_kernel<4, at::native::AddFunctor<float>, at::detail::Array<…
     0.1        149821950      62557     2395.0     1599   172257  void cask_cudnn::computeOffsetsKernel<false, false>(cask_cudnn::ComputeOffsetsParams)
     0.1        143100153      66001     2168.2     1248   267458  void at::native::vectorized_elementwise_kernel<4, at::native::FillFunctor<float>, at::detail::Array…
     0.0         17851830       6431     2775.9     1663    44640  cask_cudnn::computeWgradSplitKOffsetsKernel(cask_cudnn::ComputeSplitKOffsetsParams)
     0.0         12621596       6431     1962.6     1472    45025  cask_cudnn::computeWgradBOffsetsKernel(cask_cudnn::ComputeWgradBOffsetsParams)
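The statistics above look like Nsight Systems output. If you prefer to stay in Python, `torch.profiler` can list which operators and, with CUDA activity enabled, which GPU kernels run for a given configuration, which makes it easy to compare, for example, out_channels=100 against 150. A minimal sketch; `profile_conv1d` is a hypothetical helper, and the exact kernel names will depend on the GPU and the cuDNN version.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def profile_conv1d(out_dim, kernel_size, x):
    """Profile one forward+backward pass of a fresh Conv1d and return the profiler."""
    conv = torch.nn.Conv1d(x.shape[1], out_dim, kernel_size).to(x.device)
    activities = [ProfilerActivity.CPU]
    if x.is_cuda:
        activities.append(ProfilerActivity.CUDA)  # also record GPU kernel names
    conv(x).sum().backward()  # warmup so one-time setup is not profiled
    with profile(activities=activities) as prof:
        conv(x).sum().backward()
        if x.is_cuda:
            torch.cuda.synchronize()
    return prof


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 77, 126, device=device)
for out_dim in (80, 100, 150):
    prof = profile_conv1d(out_dim, 2, x)
    print(f"out_channels={out_dim}")
    print(prof.key_averages().table(row_limit=10))
```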

I also get the following results on a Titan V using a master build with cuDNN 8.1:

[-------------------  -------------------]
                    |   x1  |   x2  |   x3
1 threads: -------------------------------
      [77, 50, 2]   |  4.3  |  3.3  |  2.7
      [77, 50, 3]   |  1.1  |  1.1  |  2.0
      [77, 50, 4]   |  1.1  |  1.1  |  2.1
      [77, 80, 2]   |  6.6  |  5.9  |  3.1
      [77, 80, 3]   |  1.1  |  1.1  |  2.4
      [77, 80, 4]   |  1.1  |  1.0  |  2.5
      [77, 100, 2]  |  6.6  |  5.9  |  3.4
      [77, 100, 3]  |  1.1  |  1.1  |  2.7
      [77, 100, 4]  |  1.2  |  1.1  |  2.8
      [77, 120, 2]  |  6.7  |  6.0  |  3.5
      [77, 120, 3]  |  1.2  |  1.1  |  2.8
      [77, 120, 4]  |  1.2  |  1.1  |  3.0
      [77, 150, 2]  |  6.8  |  6.2  |  4.0
      [77, 150, 3]  |  1.3  |  1.2  |  3.7
      [77, 150, 4]  |  1.4  |  1.3  |  4.9
4 threads: -------------------------------
      [77, 50, 2]   |  3.7  |  3.3  |  2.7
      [77, 50, 3]   |  1.1  |  1.1  |  2.0
      [77, 50, 4]   |  1.1  |  1.1  |  2.1
      [77, 80, 2]   |  6.6  |  5.9  |  3.1
      [77, 80, 3]   |  1.1  |  1.1  |  2.4
      [77, 80, 4]   |  1.1  |  1.1  |  2.5
      [77, 100, 2]  |  6.6  |  6.0  |  3.4
      [77, 100, 3]  |  1.1  |  1.1  |  2.7
      [77, 100, 4]  |  1.2  |  1.1  |  2.8
      [77, 120, 2]  |  6.7  |  6.0  |  3.5
      [77, 120, 3]  |  1.2  |  1.1  |  2.8
      [77, 120, 4]  |  1.2  |  1.1  |  3.0
      [77, 150, 2]  |  6.8  |  6.2  |  4.0
      [77, 150, 3]  |  1.3  |  1.2  |  3.7
      [77, 150, 4]  |  1.4  |  1.3  |  4.8

Times are in milliseconds (ms).

Thanks for taking the time to look at my problem. I reinstalled PyTorch and cudatoolkit cleanly in a new environment, and now my training times behave as expected.
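Since a clean reinstall resolved it, recording the exact versions on each machine is a cheap first step when timings like these disagree between setups. The sketch below (`env_summary` is a hypothetical helper) collects the PyTorch, CUDA and cuDNN versions plus the GPU name; `python -m torch.utils.collect_env` prints a more complete report.

```python
import torch


def env_summary():
    """Versions that commonly explain performance differences between setups."""
    return {
        "torch": torch.__version__,
        "cuda": torch.version.cuda,  # None for CPU-only builds
        "cudnn": torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None,
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }


print(env_summary())
```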