torch.backends.cudnn.allow_tf32 makes my network slower

I heard that setting torch.backends.cudnn.allow_tf32 to True (which is the default) can boost performance. However, in my example it makes the network slower.
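For context, TF32 keeps float32's 8-bit exponent but only 10 explicit mantissa bits (float32 has 23). It is the format Ampere tensor cores use for float32 convolutions when allow_tf32 is enabled, trading precision for throughput. The following is a minimal pure-Python sketch of that rounding. It assumes round-to-nearest with ties away from zero (the exact rounding a given cuDNN kernel applies is an assumption here), and `to_tf32` is a hypothetical helper, not a PyTorch API.

```python
import struct

def to_tf32(x: float) -> float:
    """Round a float32 value to TF32 precision (1 sign, 8 exponent, 10 mantissa bits).

    Assumes round-to-nearest with ties away from zero; NaN/Inf inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # float32 stores 23 mantissa bits; TF32 keeps the top 10, so 13 are dropped.
    # Adding half of the dropped range before masking rounds the magnitude to nearest.
    bits = (bits + (1 << 12)) & ~((1 << 13) - 1)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for v in (1.0 + 2**-10, 1.0 + 2**-12, 1.0 + 2**-11, 0.1):
    print(v, "->", to_tf32(v))
```

Values whose relative difference is below about 2^-11 can collapse to the same TF32 number, which is the precision cost that comes with the potential speedup.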

import torch
import torch.nn as nn
import torch.nn.functional as F
import math, time
torch.backends.cudnn.allow_tf32 = True

class MyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.KL = nn.Conv3d(1, 27, kernel_size=3, padding='same')
        torch.manual_seed(0)
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.KL.weight, a=math.sqrt(5))
        if self.KL.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.KL.weight)
            if fan_in != 0:
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.KL.bias, -bound, bound)
    def forward(self, image):
        K = self.KL(image) # 1 x N x N x N -> 27 x N x N x N
        return K.mean()

model = MyBlock().cuda()
N = 64
image = torch.randn(1, N, N, N).cuda()

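for _ in range(10): # warm-up: run backward too, so one-time cuDNN setup for the backward kernels is not timed
    y = model(image)
    y.sum().backward()
torch.cuda.synchronize() # make sure the warm-up work has finished before timing starts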

iters = 100

forward = 0.0
backward = 0.0
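for _ in range(iters):
    start = time.perf_counter() # high-resolution monotonic clock for measuring intervals
    y = model(image)
    torch.cuda.synchronize() # kernel launches are asynchronous, so wait for the GPU
    forward += time.perf_counter() - start

    start = time.perf_counter()
    y.sum().backward()
    torch.cuda.synchronize()
    backward += time.perf_counter() - start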
print('\nForward: {:.3f} us | Backward {:.3f} us'.format(forward * 1e6/iters, backward * 1e6/iters))
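An alternative sketch that times each iteration on the GPU with CUDA events and reports medians instead of a sum over iterations. `time_fwd_bwd_cuda_events` is a hypothetical helper, and the batched 5D input shape is an assumption (the original code feeds an unbatched 4D tensor, which Conv3d also accepts).

```python
import statistics
import torch
import torch.nn as nn

def time_fwd_bwd_cuda_events(model, x, iters=100, warmup=10):
    """Median forward and backward GPU times in microseconds, measured with CUDA events."""
    for _ in range(warmup):  # warm up forward *and* backward
        model(x).mean().backward()
    torch.cuda.synchronize()
    fwd_us, bwd_us = [], []
    for _ in range(iters):
        start, mid, end = (torch.cuda.Event(enable_timing=True) for _ in range(3))
        start.record()
        loss = model(x).mean()
        mid.record()
        loss.backward()
        end.record()
        torch.cuda.synchronize()  # events must have completed before elapsed_time()
        fwd_us.append(start.elapsed_time(mid) * 1e3)  # elapsed_time() returns milliseconds
        bwd_us.append(mid.elapsed_time(end) * 1e3)
    return statistics.median(fwd_us), statistics.median(bwd_us)

if torch.cuda.is_available():
    model = nn.Conv3d(1, 27, kernel_size=3, padding="same").cuda()
    x = torch.randn(1, 1, 64, 64, 64, device="cuda")
    for allow in (False, True):
        torch.backends.cudnn.allow_tf32 = allow
        fwd, bwd = time_fwd_bwd_cuda_events(model, x)
        print(f"allow_tf32={allow}: forward {fwd:.1f} us | backward {bwd:.1f} us")
```

Running both settings in the same process, each after its own warm-up, also removes process-to-process variation from the comparison.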

On my NVIDIA GeForce RTX 3080, with allow_tf32 set to True, I get:

Forward: 368.619 us | Backward 1425.776 us

With allow_tf32 set to False:

Forward: 281.081 us | Backward 1096.125 us
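These numbers are averages (total time divided by iters). A plain mean is sensitive to a single slow iteration, for example one that pays a one-time setup cost; in the nsys reports below, iteration0 takes roughly 51 to 57 ms while later iterations take about 2 ms. The sample values in this sketch are hypothetical and only illustrate how far one such outlier moves the mean compared to the median.

```python
import statistics

def summarize(samples_us):
    """Return (mean, median) of per-iteration timings in microseconds."""
    return statistics.fmean(samples_us), statistics.median(samples_us)

# Hypothetical samples: 99 steady-state iterations of 1,000 us plus one first
# iteration that pays a ~50 ms (50,000 us) one-time setup cost.
samples = [1000.0] * 99 + [50000.0]
mean_us, median_us = summarize(samples)
print(f"mean: {mean_us:.1f} us, median: {median_us:.1f} us")
```

A single 50 ms outlier spread over 100 iterations adds about 490 us to the mean, which is larger than the forward/backward differences reported above, while the median is unaffected.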

Since your model is a single layer and thus tiny, could you profile the workload with nsys and post the reports here?
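A sketch of how such a report can be produced: wrap each iteration, forward pass, and backward pass in NVTX ranges with torch.cuda.nvtx.range_push/range_pop, then run the script under nsys. The range names mirror the ones in the reports that follow; the script name, the output name, and the `run_profiled` helper are hypothetical.

```python
# Hypothetical profile_conv.py; run it e.g. with:
#   nsys profile --trace=cuda,nvtx,osrt --stats=true -o tf32_on python profile_conv.py
import torch
import torch.nn as nn

def run_profiled(model, x, iters=10):
    for i in range(iters):
        torch.cuda.nvtx.range_push(f"iteration{i}")
        torch.cuda.nvtx.range_push("forward")
        loss = model(x).mean()
        torch.cuda.synchronize()  # so the CPU-side NVTX range also covers the GPU work
        torch.cuda.nvtx.range_pop()
        torch.cuda.nvtx.range_push("backward")
        loss.backward()
        torch.cuda.synchronize()
        torch.cuda.nvtx.range_pop()
        torch.cuda.nvtx.range_pop()  # closes iteration{i}

if torch.cuda.is_available():
    torch.backends.cudnn.allow_tf32 = True  # set to False for the second report
    model = nn.Conv3d(1, 27, kernel_size=3, padding="same").cuda()
    x = torch.randn(1, 1, 64, 64, 64, device="cuda")
    run_profiled(model, x)
```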

With allow_tf32 set to False, I get:

NVTX Range Statistics:

 Time (%)  Total Time (ns)  Instances    Avg (ns)      Med (ns)     Min (ns)    Max (ns)   StdDev (ns)    Style           Range         
 --------  ---------------  ---------  ------------  ------------  ----------  ----------  ------------  -------  ----------------------
     44.1       65,927,301         10   6,592,730.1   1,130,006.5     924,077  56,206,430  17,432,667.8  PushPop  backward              
     38.5       57,559,756          1  57,559,756.0  57,559,756.0  57,559,756  57,559,756           0.0  PushPop  iteration0            
      5.3        7,865,311         10     786,531.1     749,612.5     705,403   1,117,050     119,426.0  PushPop  forward               
      1.4        2,071,999          1   2,071,999.0   2,071,999.0   2,071,999   2,071,999           0.0  PushPop  iteration1            
      1.4        2,021,433          1   2,021,433.0   2,021,433.0   2,021,433   2,021,433           0.0  PushPop  iteration5            
      1.3        2,012,912          1   2,012,912.0   2,012,912.0   2,012,912   2,012,912           0.0  PushPop  iteration9            
      1.3        1,996,151          1   1,996,151.0   1,996,151.0   1,996,151   1,996,151           0.0  PushPop  iteration8            
      1.3        1,980,575          1   1,980,575.0   1,980,575.0   1,980,575   1,980,575           0.0  PushPop  iteration6            
      1.3        1,962,207          1   1,962,207.0   1,962,207.0   1,962,207   1,962,207           0.0  PushPop  iteration7            
      1.3        1,912,620          1   1,912,620.0   1,912,620.0   1,912,620   1,912,620           0.0  PushPop  iteration2            
      1.2        1,853,952          1   1,853,952.0   1,853,952.0   1,853,952   1,853,952           0.0  PushPop  iteration3            
      1.2        1,853,393          1   1,853,393.0   1,853,393.0   1,853,393   1,853,393           0.0  PushPop  iteration4            
      0.2          372,886          1     372,886.0     372,886.0     372,886     372,886           0.0  PushPop  cuBLAS:cublasCreate_v2

[4/8] Executing 'osrtsum' stats report

Operating System Runtime API Statistics:

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)       Med (ns)      Min (ns)     Max (ns)    StdDev (ns)            Name         
 --------  ---------------  ---------  -------------  -------------  -----------  -----------  ------------  ----------------------
     57.3      100,159,523          1  100,159,523.0  100,159,523.0  100,159,523  100,159,523           0.0  poll                  
     42.3       73,876,075         20    3,693,803.8    1,008,796.0      359,825   55,493,554  12,193,993.9  pthread_cond_wait     
      0.3          467,456         47        9,945.9        8,102.0        1,397       51,893      11,392.3  ioctl                 
      0.1          216,998        126        1,722.2        1,466.0        1,396        4,609         676.1  pthread_cond_signal   
      0.0           50,565          1       50,565.0       50,565.0       50,565       50,565           0.0  pthread_create        
      0.0            7,053          5        1,410.6        1,397.0        1,396        1,466          31.0  pthread_cond_broadcast

[5/8] Executing 'cudaapisum' stats report

CUDA API Statistics:

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)    Med (ns)   Min (ns)   Max (ns)   StdDev (ns)               Name            
 --------  ---------------  ---------  ------------  ---------  --------  ----------  ------------  ----------------------------
     47.2       43,227,017          3  14,409,005.7    6,216.0     4,679  43,216,122  24,947,694.6  cudaFree                    
     47.2       43,218,915          3  14,406,305.0    3,352.0     2,305  43,213,258  24,947,553.1  cudaFree                    
      1.6        1,485,675         98      15,159.9   12,572.0     9,359      81,994      10,043.0  cudaLaunchKernel            
      1.4        1,262,812         98      12,885.8   10,406.0     6,705      79,131      10,035.4  cudaLaunchKernel            
      0.6          513,608        378       1,358.8    1,397.0       908       2,794         234.9  cuGetProcAddress            
      0.4          383,848         31      12,382.2   11,105.0     9,568      29,613       4,002.7  cudaMemsetAsync             
      0.3          313,377         31      10,108.9    9,079.0     7,124      26,820       3,967.1  cudaMemsetAsync             
      0.3          287,821         16      17,988.8    6,251.0     5,587     168,458      40,584.6  cudaStreamCreateWithFlags   
      0.2          186,688          1     186,688.0  186,688.0   186,688     186,688           0.0  cuCtxSynchronize            
      0.2          146,109          4      36,527.3   13,549.0     8,451     110,560      49,414.5  cudaMalloc                  
      0.2          143,805         20       7,190.3    7,019.0     5,239      13,969       1,734.7  cudaEventRecord             
      0.1          136,682          4      34,170.5   11,175.0     6,147     108,185      49,400.7  cudaMalloc                  
      0.1           98,685         20       4,934.3    4,714.0     3,283      12,083       1,817.2  cudaEventRecord             
      0.1           56,989         19       2,999.4    1,956.0     1,397       9,358       1,918.5  cudaEventCreateWithFlags    
      0.1           49,448         20       2,472.4    2,374.5     1,885       3,353         397.0  cudaStreamIsCapturing_v10000
      0.0           22,839          1      22,839.0   22,839.0    22,839      22,839           0.0  cudaHostAlloc               
      0.0           20,883          1      20,883.0   20,883.0    20,883      20,883           0.0  cudaHostAlloc               
      0.0            2,445          1       2,445.0    2,445.0     2,445       2,445           0.0  cuInit                      
      0.0            1,886          1       1,886.0    1,886.0     1,886       1,886           0.0  cuMemHostGetDevicePointer_v2
      0.0            1,397          1       1,397.0    1,397.0     1,397       1,397           0.0  cuModuleGetLoadingMode      

[6/8] Executing 'gpukernsum' stats report

CUDA Kernel Statistics:

 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     63.2        5,664,331         10  566,433.1  566,071.5   565,752   567,672        644.0  void wgrad_alg1_nd_float_engine<float, float, (int)3, (int)0, (int)5, (int)7, (int)4, (int)3, (int)…
     11.0          981,936         10   98,193.6   98,079.0    97,694    98,815        325.8  void implicit_convolveNd_sgemm<float, (int)3, (int)128, (int)5, (int)5, (int)3, (int)3, (int)3, (in…
      9.4          839,126         10   83,912.6   83,999.0    83,167    84,351        394.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::…
      5.8          519,286         20   25,964.3   25,920.0     3,584    48,958     22,857.6  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::func_wrapp…
      5.5          494,486         10   49,448.6   49,454.5    49,312    49,599         87.7  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
      4.4          396,921         10   39,692.1   39,695.0    39,456    40,223        214.9  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::…
      0.5           47,294         18    2,627.4    2,655.5     2,528     2,688         58.9  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctor_add<float>, at::deta…
      0.2           21,790         10    2,179.0    2,176.0     2,143     2,240         35.2  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::…

[7/8] Executing 'gpumemtimesum' stats report

CUDA Memory Operation Statistics (by time):

 Time (%)  Total Time (ns)  Count  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)    Operation  
 --------  ---------------  -----  --------  --------  --------  --------  -----------  -------------
    100.0           30,336     31     978.6     928.0       864     2,592        305.1  [CUDA memset]

[8/8] Executing 'gpumemsizesum' stats report

CUDA Memory Operation Statistics (by size):

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)    Operation  
 ----------  -----  --------  --------  --------  --------  -----------  -------------
      0.014     31     0.000     0.000     0.000     0.013        0.002  [CUDA memset]

With allow_tf32 set to True:

NVTX Range Statistics:

 Time (%)  Total Time (ns)  Instances    Avg (ns)      Med (ns)     Min (ns)    Max (ns)   StdDev (ns)    Style           Range         
 --------  ---------------  ---------  ------------  ------------  ----------  ----------  ------------  -------  ----------------------
     42.7       60,504,152         10   6,050,415.2   1,258,376.0     996,015  49,626,209  15,311,893.5  PushPop  backward              
     36.3       51,315,196          1  51,315,196.0  51,315,196.0  51,315,196  51,315,196           0.0  PushPop  iteration0            
      6.7        9,430,330         10     943,033.0     889,925.0     839,709   1,419,396     170,547.6  PushPop  forward               
      1.7        2,414,154          1   2,414,154.0   2,414,154.0   2,414,154   2,414,154           0.0  PushPop  iteration8            
      1.7        2,390,547          1   2,390,547.0   2,390,547.0   2,390,547   2,390,547           0.0  PushPop  iteration7            
      1.7        2,372,878          1   2,372,878.0   2,372,878.0   2,372,878   2,372,878           0.0  PushPop  iteration9            
      1.7        2,364,427          1   2,364,427.0   2,364,427.0   2,364,427   2,364,427           0.0  PushPop  iteration1            
      1.6        2,276,705          1   2,276,705.0   2,276,705.0   2,276,705   2,276,705           0.0  PushPop  iteration2            
      1.5        2,136,254          1   2,136,254.0   2,136,254.0   2,136,254   2,136,254           0.0  PushPop  iteration6            
      1.5        2,095,326          1   2,095,326.0   2,095,326.0   2,095,326   2,095,326           0.0  PushPop  iteration4            
      1.4        2,025,274          1   2,025,274.0   2,025,274.0   2,025,274   2,025,274           0.0  PushPop  iteration3            
      1.4        2,018,988          1   2,018,988.0   2,018,988.0   2,018,988   2,018,988           0.0  PushPop  iteration5            
      0.1          209,595          1     209,595.0     209,595.0     209,595     209,595           0.0  PushPop  cuBLAS:cublasCreate_v2

[4/8] Executing 'osrtsum' stats report

Operating System Runtime API Statistics:

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)       Med (ns)      Min (ns)     Max (ns)    StdDev (ns)            Name         
 --------  ---------------  ---------  -------------  -------------  -----------  -----------  ------------  ----------------------
     58.8      100,167,418          1  100,167,418.0  100,167,418.0  100,167,418  100,167,418           0.0  poll                  
     40.9       69,654,065         20    3,482,703.3    1,205,610.0      238,301   48,951,328  10,705,522.6  pthread_cond_wait     
      0.2          336,779         45        7,484.0        6,146.0        1,396       45,188       8,683.1  ioctl                 
      0.1          207,222        117        1,771.1        1,466.0        1,396        6,635         815.2  pthread_cond_signal   
      0.0           81,016          1       81,016.0       81,016.0       81,016       81,016           0.0  pthread_create        
      0.0            8,661          6        1,443.5        1,466.5        1,397        1,467          36.0  pthread_cond_broadcast

[5/8] Executing 'cudaapisum' stats report

CUDA API Statistics:

 Time (%)  Total Time (ns)  Num Calls    Avg (ns)    Med (ns)   Min (ns)   Max (ns)   StdDev (ns)               Name            
 --------  ---------------  ---------  ------------  ---------  --------  ----------  ------------  ----------------------------
     46.5       39,396,681          3  13,132,227.0    5,169.0     3,841  39,387,671  22,737,881.5  cudaFree                    
     46.5       39,389,556          3  13,129,852.0    2,794.0     1,466  39,385,296  22,737,881.5  cudaFree                    
      2.4        2,050,982        158      12,980.9   11,384.5     7,054      89,677       8,886.0  cudaLaunchKernel            
      2.0        1,687,170        158      10,678.3    9,149.0     5,098      87,232       8,881.6  cudaLaunchKernel            
      0.5          459,903        378       1,216.7    1,396.5       907       2,374         251.7  cuGetProcAddress            
      0.4          344,948         31      11,127.4   10,616.0     7,683      20,673       2,454.4  cudaMemsetAsync             
      0.3          293,475          1     293,475.0  293,475.0   293,475     293,475           0.0  cuCtxSynchronize            
      0.3          275,873         31       8,899.1    8,730.0     5,657      18,717       2,443.7  cudaMemsetAsync             
      0.3          215,670         16      13,479.4    4,574.5     4,191     130,534      31,442.4  cudaStreamCreateWithFlags   
      0.2          135,075         20       6,753.8    6,740.0     4,680      10,337       1,399.0  cudaEventRecord             
      0.1           93,587          4      23,396.8   11,244.0     5,727      65,372      28,221.0  cudaMalloc                  
      0.1           88,631         20       4,431.6    4,470.0     2,374       8,451       1,392.4  cudaEventRecord             
      0.1           84,300          4      21,075.0    9,044.5     2,794      63,417      28,479.9  cudaMalloc                  
      0.1           46,653         20       2,332.7    2,374.0     1,397       3,352         420.9  cudaStreamIsCapturing_v10000
      0.0           30,732         19       1,617.5    1,467.0       908       2,794         445.0  cudaEventCreateWithFlags    
      0.0           20,464          1      20,464.0   20,464.0    20,464      20,464           0.0  cudaHostAlloc               
      0.0           18,089          1      18,089.0   18,089.0    18,089      18,089           0.0  cudaHostAlloc               
      0.0            1,885          1       1,885.0    1,885.0     1,885       1,885           0.0  cuInit                      
      0.0            1,467          1       1,467.0    1,467.0     1,467       1,467           0.0  cuMemHostGetDevicePointer_v2
      0.0            1,397          1       1,397.0    1,397.0     1,397       1,397           0.0  cuModuleGetLoadingMode      

[6/8] Executing 'gpukernsum' stats report

CUDA Kernel Statistics:

 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     57.8        7,203,562         10  720,356.2  719,492.5   718,356   725,172      2,262.6  sm80_xmma_wgrad_implicit_gemm_indexed_wo_smem_tf32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize32x16x64_st…
     11.2        1,400,777         40   35,019.4   26,239.0     2,783    85,503     30,729.7  void cudnn::ops::nchwToNhwcKernel<float, float, float, (bool)0, (bool)1, (cudnnKernelDataType_t)2>(…
      7.1          883,827         20   44,191.4   44,176.0     3,168    85,534     41,990.8  void cudnn::ops::nhwcToNchwKernel<float, float, float, (bool)1, (bool)0, (cudnnKernelDataType_t)0>(…
      6.8          840,850         10   84,085.0   84,190.0    83,519    84,511        375.7  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::…
      5.1          635,734         10   63,573.4   63,439.0    63,295    63,935        265.6  sm80_xmma_fprop_implicit_gemm_indexed_wo_smem_tf32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize128x16x32_s…
      4.2          528,695         20   26,434.8   26,352.0     3,680    49,504     23,225.9  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::func_wrapp…
      4.0          493,495         10   49,349.5   49,359.0    49,215    49,439         70.4  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::MeanOps<fl…
      3.2          395,993         10   39,599.3   39,583.5    39,423    39,807        113.1  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl<at::native::…
      0.4           47,294         18    2,627.4    2,624.0     2,559     2,784         70.2  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctor_add<float>, at::deta…
      0.2           22,303         10    2,230.3    2,240.0     2,207     2,240         15.6  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<float>, at::detail::…

[7/8] Executing 'gpumemtimesum' stats report

CUDA Memory Operation Statistics (by time):

 Time (%)  Total Time (ns)  Count  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)    Operation  
 --------  ---------------  -----  --------  --------  --------  --------  -----------  -------------
    100.0           30,304     31     977.5     928.0       896     2,592        306.1  [CUDA memset]

[8/8] Executing 'gpumemsizesum' stats report

CUDA Memory Operation Statistics (by size):

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)    Operation  
 ----------  -----  --------  --------  --------  --------  -----------  -------------
      0.015     31     0.000     0.000     0.000     0.013        0.002  [CUDA memset]
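To compare the two kernel tables directly, one can add up the total times of the kernels that differ between the runs. This small arithmetic sketch uses the "Total Time (ns)" values from the two gpukernsum reports above. Grouping the kernels this way is an editorial choice; the elementwise, reduce, add, and fill kernels take about the same time in both runs and are left out.

```python
# Total kernel times in ns over 10 iterations, copied from the gpukernsum reports.
fp32 = {
    "wgrad (wgrad_alg1_nd_float_engine)": 5_664_331,
    "fprop (implicit_convolveNd_sgemm)": 981_936,
}
tf32 = {
    "wgrad (sm80_xmma_wgrad, tf32)": 7_203_562,
    "fprop (sm80_xmma_fprop, tf32)": 635_734,
    "nchwToNhwcKernel": 1_400_777,
    "nhwcToNchwKernel": 883_827,
}
iters = 10
fp32_us = sum(fp32.values()) / iters / 1e3  # per-iteration time in microseconds
tf32_us = sum(tf32.values()) / iters / 1e3
transpose_us = (tf32["nchwToNhwcKernel"] + tf32["nhwcToNchwKernel"]) / iters / 1e3
print(f"convolution kernels per iteration: FP32 {fp32_us:.1f} us, TF32 {tf32_us:.1f} us")
print(f"of which layout transposes: {transpose_us:.1f} us")
```

Per iteration this comes to about 665 us for FP32 and about 1012 us for TF32, with the layout transposes alone accounting for about 228 us.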

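The TF32 forward kernel itself is faster, as you can see in sm80_xmma_fprop_implicit_gemm_indexed_wo_smem_tf32f32_tf32f32_f32_nhwckrsc_nhwc_tilesize128x16x32_s (about 63.6 us on average vs. about 98.2 us for the FP32 implicit_convolveNd_sgemm kernel). The TF32 weight-gradient kernel, however, is slower than its FP32 counterpart (about 720 us vs. about 566 us for wgrad_alg1_nd_float_engine). In addition, the TF32 path launches nchwToNhwcKernel and nhwcToNchwKernel transposes because these kernels work on NHWC data. These additional transposes are likely a large part of the slowdown, so you could check whether using the channels_last memory layout helps.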
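A minimal sketch of trying that suggestion for this Conv3d model. It assumes a batched 5D input, because for 5D tensors the corresponding format is torch.channels_last_3d (torch.channels_last applies to 4D tensors). Whether cuDNN then skips the transposes is not guaranteed and should be confirmed with another nsys run. `to_channels_last_3d` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def to_channels_last_3d(model: nn.Module, x: torch.Tensor):
    """Convert a Conv3d model and its input to channels_last_3d.

    Shapes stay the same; only the strides (memory layout) change.
    """
    if x.dim() == 4:
        # The original post feeds an unbatched (C, D, H, W) tensor; add a batch
        # dimension because channels_last_3d is defined for 5D tensors only.
        x = x.unsqueeze(0)
    model = model.to(memory_format=torch.channels_last_3d)  # converts 4D/5D params
    x = x.contiguous(memory_format=torch.channels_last_3d)
    return model, x

device = "cuda" if torch.cuda.is_available() else "cpu"
conv = nn.Conv3d(1, 27, kernel_size=3, padding="same").to(device)
image = torch.randn(1, 8, 8, 8, device=device)  # small N for a quick check
conv, image_cl = to_channels_last_3d(conv, image)
out = conv(image_cl)
print(out.shape, out.is_contiguous(memory_format=torch.channels_last_3d))
```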

Is there an overhead to using TF32? How should I properly compare the timings of two different configurations? Is my approach valid?
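One way to compare the two settings is torch.utils.benchmark.Timer, which synchronizes CUDA around its measurements and reports statistics such as the median. This is a sketch that assumes a forward plus backward pass is the unit you want to time; `step` and `compare_tf32` are hypothetical helpers.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

def step(model, x):
    """One forward + backward pass, similar to what the original script times."""
    model.zero_grad(set_to_none=True)
    model(x).mean().backward()

def compare_tf32(model, x, min_run_time=1.0):
    """Return {allow_tf32: Measurement} for both cuDNN TF32 settings."""
    previous = torch.backends.cudnn.allow_tf32
    results = {}
    try:
        for allow in (False, True):
            torch.backends.cudnn.allow_tf32 = allow
            for _ in range(3):  # warm-up so one-time cuDNN setup is not measured
                step(model, x)
            timer = benchmark.Timer(
                stmt="step(model, x)",
                globals={"step": step, "model": model, "x": x},
                label="Conv3d forward+backward",
                sub_label=f"allow_tf32={allow}",
            )
            results[allow] = timer.blocked_autorange(min_run_time=min_run_time)
    finally:
        torch.backends.cudnn.allow_tf32 = previous  # restore the caller's setting
    return results

if torch.cuda.is_available():
    model = nn.Conv3d(1, 27, kernel_size=3, padding="same").cuda()
    x = torch.randn(1, 1, 64, 64, 64, device="cuda")
    for allow, m in compare_tf32(model, x).items():
        print(f"allow_tf32={allow}: median {m.median * 1e6:.1f} us")
```

Comparing medians from the same process, each measured after its own warm-up, keeps one-time setup costs and occasional slow iterations out of the comparison.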