ConvTranspose1d extremely slow on GPU (T4), slower than CPU

Hi,

I’m puzzled that torch.nn.ConvTranspose1d is extremely slow when running on the GPU, even slower than on the CPU.

Code to reproduce:

$ cat test_trans_conv.py

import time
import torch

# Input: batch of 1, 64 channels, 40000 time steps
x = torch.randn(1, 64, 40000)
trans_conv = torch.nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)
if torch.cuda.is_available():
    x = x.cuda()
    trans_conv.to('cuda')

num = 100
with torch.no_grad():
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure no prior work skews the timing
    start = time.time()
    for i in range(num):
        y = trans_conv(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for all queued kernels to finish
    end = time.time()
    print('average cost: {}ms'.format((end - start) * 1000 / num))

GPU test on T4 (16GB):
OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0 python test_trans_conv.py

average cost: 61.4864182472229ms

CPU test:
OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES="" python test_trans_conv.py

average cost: 19.696030616760254ms

However, when running on a V100 (32GB), performance is fine.

GPU test on V100 (32GB):
OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0 python test_trans_conv.py

average cost: 5.709564685821533ms

Are there some optimizations for the V100 that do not apply to the T4?

Environment:

  • pytorch 1.5.1
  • torchvision 0.6.1
  • GPU: T4 16GB
  • CPU: Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz

nvidia-smi output:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:00:09.0 Off |                    0 |
| N/A   26C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

lspci output:

00:09.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)

nvprof output:

==44916== NVPROF is profiling process 44916, command: python test_trans_conv.py
==44916== Profiling application: python test_trans_conv.py
==44916== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   99.75%  6.14003s       100  61.400ms  60.137ms  112.24ms  void cudnn::detail::dgrad2d_alg1_1<float, int=0, int=6, int=6, int=5, int=4, int=4, bool=1, bool=1>(int, int, int, float const *, int, float const , int, cudnn::detail::dgrad2d_alg1_1<float, int=0, int=6, int=6, int=5, int=4, int=4, bool=1, bool=1>*, kernel_grad_params, int, int, float, int, int)
                    0.14%  8.7692ms       100  87.692us  87.140us  94.823us  _ZN2at6native6legacy18elementwise_kernelILi128ELi4EZNS0_15gpu_kernel_implIZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE2_clEvEUlffE_EEvS5_RKT_EUliE2_EEviT1_
                    0.07%  4.4278ms       100  44.277us  43.410us  51.797us  void scalePackedTensor_kernel<float, float>(cudnnTensor4dStruct, float*, float)
                    0.03%  1.9988ms         4  499.70us  1.0250us  1.9928ms  [CUDA memcpy HtoD]
                    0.00%  3.7460us         4     936ns     896ns     993ns  [CUDA memset]
      API calls:   50.69%  6.14258s         2  3.07129s  20.710us  6.14256s  cudaDeviceSynchronize
                   49.22%  5.96409s        11  542.19ms  5.4380us  5.96339s  cudaMalloc
                    0.02%  2.8366ms       300  9.4550us  7.1300us  30.460us  cudaLaunchKernel
                    0.02%  2.2235ms         3  741.17us  9.2310us  2.1946ms  cudaMemcpyAsync
                    0.01%  1.0543ms         2  527.13us  523.45us  530.81us  cuDeviceTotalMem
                    0.01%  950.94us         1  950.94us  950.94us  950.94us  cudaHostAlloc
                    0.00%  595.08us      1734     343ns     264ns  4.8750us  cudaGetDevice
                    0.00%  456.79us      1111     411ns     329ns  13.294us  cudaSetDevice
                    0.00%  376.66us       200  1.8830us  1.2390us  13.603us  cudaBindTexture
                    0.00%  360.93us       191  1.8890us     135ns  78.127us  cuDeviceGetAttribute
                    0.00%  323.61us         2  161.80us  161.70us  161.90us  cudaGetDeviceProperties
                    0.00%  154.79us         4  38.696us  1.9710us  148.11us  cudaStreamCreateWithPriority
                    0.00%  140.18us       100  1.4010us  1.2120us  4.6170us  cudaEventRecord
                    0.00%  136.54us       200     682ns     497ns  2.3080us  cudaUnbindTexture
                    0.00%  118.33us       169     700ns     431ns  3.7450us  cudaFuncSetAttribute
                    0.00%  86.265us         3  28.755us  6.1770us  69.597us  cudaStreamSynchronize
                    0.00%  71.269us         4  17.817us  8.4940us  43.461us  cudaMemsetAsync
                    0.00%  70.467us         2  35.233us  20.860us  49.607us  cuDeviceGetName
                    0.00%  47.772us       300     159ns     137ns     317ns  cudaGetLastError
                    0.00%  32.743us         8  4.0920us  1.9660us  17.648us  cudaStreamCreateWithFlags
                    0.00%  24.566us        30     818ns     483ns  3.1420us  cudaEventCreateWithFlags
                    0.00%  18.974us         1  18.974us  18.974us  18.974us  cudaMemcpy
                    0.00%  13.878us        29     478ns     271ns  3.5180us  cudaDeviceGetAttribute
                    0.00%  8.9030us        25     356ns     123ns  2.2330us  cudaGetDeviceCount
                    0.00%  5.7300us         1  5.7300us  5.7300us  5.7300us  cuDeviceGetPCIBusId
                    0.00%  4.6730us         4  1.1680us     439ns  2.4560us  cudaFree
                    0.00%  2.4920us         1  2.4920us  2.4920us  2.4920us  cudaDeviceGetStreamPriorityRange
                    0.00%  1.9250us         1  1.9250us  1.9250us  1.9250us  cudaHostGetDevicePointer
                    0.00%  1.5120us         4     378ns     127ns     906ns  cuDeviceGetCount
                    0.00%     908ns         3     302ns     169ns     524ns  cuDeviceGet
                    0.00%     747ns         1     747ns     747ns     747ns  cuInit
                    0.00%     477ns         2     238ns     216ns     261ns  cuDeviceGetUuid
                    0.00%     355ns         1     355ns     355ns     355ns  cuDriverGetVersion

The kernel cudnn::detail::dgrad2d_alg1_1 accounts for 99.75% of the GPU time. What does this kernel do, and why does it take so long?

How can I debug and accelerate ConvTranspose1d?
Any suggestions would be appreciated, thanks!

Could someone please run this script on their machine to check whether they see the same issue?

Thanks in advance!

It seems cudnn selects a bad kernel in the default setup. You can set torch.backends.cudnn.benchmark = True to enable cudnn benchmark mode, which selects the fastest kernel for each input shape.
In this mode the first iteration will be slower, as multiple algorithms are executed to find the fastest one.
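
The change is a single line at the top of the script (a minimal sketch, reusing your test_trans_conv.py):

import torch

# Enable cudnn benchmark mode: for each new input shape, cudnn times
# the available algorithms once and caches the fastest one.
torch.backends.cudnn.benchmark = True

# ... rest of test_trans_conv.py unchanged ...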

After setting torch.backends.cudnn.benchmark = True at the beginning of the script, I get ~11.44ms.
The fast kernel seems to be already selected in the upcoming cudnn8 release.

Yes, after setting torch.backends.cudnn.benchmark = True at the beginning of the script, I got ~2.0008ms. Thank you for sharing this!

However, the model I use accepts variable-length inputs rather than a fixed shape, so benchmark mode searches for the fastest kernel again every time it encounters a new input shape. A small sketch of what I mean is below.
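
A sketch illustrating the cost (the lengths here are made up for illustration): with benchmark mode enabled, only the first iteration for each new length pays the algorithm-search overhead.

import time
import torch

torch.backends.cudnn.benchmark = True
conv = torch.nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1).cuda()

with torch.no_grad():
    for length in (40000, 40001, 40002):  # every new length triggers a new search
        x = torch.randn(1, 64, length, device='cuda')
        for i in range(2):
            torch.cuda.synchronize()
            start = time.time()
            y = conv(x)
            torch.cuda.synchronize()
            print('length={} iter={}: {:.2f}ms'.format(
                length, i, (time.time() - start) * 1000))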

The fast kernel seems to be already selected in the upcoming cudnn8 release.

By the way, which version of PyTorch ships with cudnn8? The latest stable version, 1.5.1, comes with cudnn 7.6.5.

Depending on how many different input shapes you have, the initial slowdown might be acceptable. Or is every input shape a new one?

None so far, as we are currently creating PRs to enable cudnn8 in the source builds (cudnn8 isn’t released yet, just the RC :wink: ).

@ptrblck

Every input shape is a new one :joy:

Is there a way we can explicitly tell cudnn to select a faster kernel instead of the bad one?

Have you tried this script on cudnn8 without setting torch.backends.cudnn.benchmark = True?

By the way, do you know what the kernel cudnn::detail::dgrad2d_alg1_1 does? I searched Google but couldn’t find any useful info. :laughing:

No, there is no user-facing API to do so.

Yes, I ran it with the default setup first and compared it to benchmark mode, which didn’t yield a further speedup.

It’s a cudnn “data gradient” kernel, which computes the gradient for the input activation.

Thank you for your explanations!

How do I set up PyTorch with cudnn8? By compiling from source with the system cudnn8 library installed?

Gradients are still computed even with torch.no_grad()?

Yes, this would be the current option.

No, but the backward pass through a “vanilla” convolution is a transposed convolution.
In the same manner, the forward pass of a transposed convolution can use the “backward kernel” of a vanilla convolution, which is why you are seeing the dgrad kernel here.
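
You can check this relationship numerically. A minimal sketch (shapes chosen to mirror your script; note that a (64, 32, 4) weight is simultaneously a valid conv1d weight with out=64, in=32, k=4 and a valid conv_transpose1d weight with in=64, out=32, k=4):

import torch
import torch.nn.functional as F

# Shared weight tensor usable by both conv1d and conv_transpose1d.
weight = torch.randn(64, 32, 4)

# Forward pass of the transposed convolution, as in the script above.
g = torch.randn(1, 64, 40000)
y_transposed = F.conv_transpose1d(g, weight, stride=2, padding=1)

# The same result via the data gradient ("dgrad") of a vanilla conv1d:
# run conv1d forward, then backpropagate g as the output gradient.
x = torch.randn(1, 32, 80000, requires_grad=True)
out = F.conv1d(x, weight, stride=2, padding=1)  # shape (1, 64, 40000)
out.backward(g)

print(torch.allclose(y_transposed, x.grad, atol=1e-4))  # True

Up to float tolerance, the forward output of the transposed convolution matches the data gradient of the vanilla convolution, which is exactly what the dgrad kernel computes.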


By following the instructions at this URL? GitHub - pytorch/pytorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration

That makes sense :wink:


Yes, you could follow these instructions and install CUDA 11 + the cudnn8 RC locally.
If you would like to test these library versions, you could also try the 20.06 NGC docker container.