Torch Execution Time Differences: 1.13.0 vs. 2.4.0

I recently switched to PyTorch 2.4.0 after using PyTorch 1.13.0 for quite some time. Since PyTorch 2.0.0 came with many new features, I was interested in trying them out and, of course, also hoped for some reductions in execution time.

However, after testing, I found that for my specific use case PyTorch 2.4.0 is up to 50% slower than PyTorch 1.13.0 with the same code.

I tried to come up with a minimal code example that reproduces part of the issue. Have a look at the following:

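import time

import torch

torch.manual_seed(0)
device = 'cuda:0'
start = 3110
end = 3115
step = 1
for n_mats in range(start, end, step):
    # n_mats float32 matrices with random shapes (2..299 rows/cols), created on the CPU and copied to the GPU
    a = [torch.randn(*torch.randint(2, 300, (2,))).to(device).to(torch.float32) for _ in range(n_mats)]
    torch.cuda.synchronize()  # make sure all copies have finished before timing

    start_time = time.perf_counter()
    c = [(aa @ aa.T) for aa in a]  # one small matrix multiplication per matrix
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock

    print(f"Elapsed time: {time.perf_counter() - start_time:.4f} seconds, n_mats: {n_mats}")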

The results I get are the following:

PyTorch version 1.13.0

Elapsed time: 0.4380 seconds, n_mats: 3110
Elapsed time: 0.0331 seconds, n_mats: 3111
Elapsed time: 0.0288 seconds, n_mats: 3112
Elapsed time: 0.0295 seconds, n_mats: 3113
Elapsed time: 0.0301 seconds, n_mats: 3114

PyTorch version 2.4.0

Elapsed time: 0.0739 seconds, n_mats: 3110
Elapsed time: 0.0439 seconds, n_mats: 3111
Elapsed time: 0.1221 seconds, n_mats: 3112
Elapsed time: 0.3828 seconds, n_mats: 3113
Elapsed time: 0.3765 seconds, n_mats: 3114

Please let me know if this way of measuring the runtime is not representative and should be adjusted.
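One common way to make such measurements more representative is to run a few untimed warm-up iterations first (so one-time costs such as CUDA context and cuBLAS initialization are excluded) and then report the median of several timed repeats. A minimal sketch, assuming a generic `sync` callable (in practice `torch.cuda.synchronize`); `time_repeats` and `matmul_batch_benchmark` are hypothetical helpers, not PyTorch APIs, and the matrix shapes are drawn differently from the loop above, so they will not match it exactly:

```python
import statistics
import time
from typing import Callable, List


def time_repeats(fn: Callable[[], object], sync: Callable[[], None],
                 warmup: int = 2, repeats: int = 10) -> List[float]:
    """Return the wall-clock time (seconds) of each of `repeats` calls to `fn`.

    `sync` is called after the warm-up and after every timed call, so each
    measurement starts and ends with no queued device work.
    """
    for _ in range(warmup):  # untimed: absorbs one-time initialization costs
        fn()
    sync()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        sync()
        times.append(time.perf_counter() - start)
    return times


def matmul_batch_benchmark(n_mats: int, device: str = "cuda:0") -> float:
    """Median time of the per-matrix `aa @ aa.T` loop (hypothetical helper)."""
    import torch  # imported lazily so the harness above has no torch dependency

    torch.manual_seed(0)
    sizes = torch.randint(2, 300, (n_mats, 2)).tolist()  # [rows, cols] per matrix
    a = [torch.randn(r, c, device=device) for r, c in sizes]  # created directly on the device
    times = time_repeats(lambda: [aa @ aa.T for aa in a], torch.cuda.synchronize)
    return statistics.median(times)
```

Calling `matmul_batch_benchmark(3113)` under each PyTorch version would then compare steady-state medians rather than single cold measurements.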

Comparing these two results, both versions seem to have some additional overhead in the first iteration. For PyTorch 1.13.0 this overhead is quite drastic, whereas for 2.4.0 it is much smaller. After the first iteration, however, 1.13.0 runs quite fast and its runtime stays at roughly 0.03 seconds, which is expected since the number of matrix multiplications grows by only one per iteration. In contrast, 2.4.0 jumps by roughly 3x in iteration 3 and by another 3x in iteration 4.
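The jumps described above can be quantified as the ratio between consecutive iteration times. A small sketch using only the numbers from the two logs above (`step_ratios` is a hypothetical helper):

```python
from typing import List

# Elapsed times (seconds) for n_mats = 3110..3114, copied from the logs above
TIMES_1_13 = [0.4380, 0.0331, 0.0288, 0.0295, 0.0301]
TIMES_2_4 = [0.0739, 0.0439, 0.1221, 0.3828, 0.3765]


def step_ratios(times: List[float]) -> List[float]:
    """Ratio of each iteration's time to the previous iteration's time."""
    return [cur / prev for prev, cur in zip(times, times[1:])]


for label, times in (("1.13.0", TIMES_1_13), ("2.4.0", TIMES_2_4)):
    print(label, [round(r, 2) for r in step_ratios(times)])
```

For 2.4.0 this gives roughly 2.8x and 3.1x for the third and fourth iterations, while the 1.13.0 ratios after the first iteration stay within about 15% of 1.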

Can anyone explain what is causing these runtimes?

I’d suggest running multiple tests for each version. I just ran 3x10 runs each; execution times vary in both cases, but when averaged across runs they are consistent between the two versions. The per-run averages varied between 0.14772 and 0.15440 seconds for PyTorch 2.4 and between 0.1504 and 0.1542 seconds for PyTorch 1.13.
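Averaging over repeated runs, as suggested here, can be done with a small helper that reports the mean and spread of the per-run averages. A sketch, assuming each run is recorded as a list of per-iteration times in seconds; `summarize_runs` and its `skip_first` option are hypothetical, and dropping the first iteration is one way to keep the cold-start cost out of the average:

```python
import statistics
from typing import Dict, List


def summarize_runs(runs: List[List[float]], skip_first: bool = True) -> Dict[str, float]:
    """Mean, min, max and stdev of the per-run average iteration time."""
    per_run = []
    for times in runs:
        kept = times[1:] if skip_first else times  # optionally drop the warm-up iteration
        per_run.append(statistics.mean(kept))
    return {
        "mean": statistics.mean(per_run),
        "min": min(per_run),
        "max": max(per_run),
        "stdev": statistics.stdev(per_run) if len(per_run) > 1 else 0.0,
    }
```

Note that a single average per run can hide a step change within a run, so the per-iteration times are still worth looking at.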

Maybe I should have made this clearer by running more iterations. For example, increasing the number of iterations to 20 by setting the end of the interval to 3130 gives:

PyTorch 1.13.0:

Elapsed time: 0.4544 seconds, n_mats: 3110
Elapsed time: 0.0353 seconds, n_mats: 3111
Elapsed time: 0.0314 seconds, n_mats: 3112
Elapsed time: 0.0317 seconds, n_mats: 3113
Elapsed time: 0.0322 seconds, n_mats: 3114
Elapsed time: 0.0317 seconds, n_mats: 3115
Elapsed time: 0.0318 seconds, n_mats: 3116
Elapsed time: 0.0321 seconds, n_mats: 3117
Elapsed time: 0.0317 seconds, n_mats: 3118
Elapsed time: 0.0317 seconds, n_mats: 3119
Elapsed time: 0.0344 seconds, n_mats: 3120
Elapsed time: 0.0319 seconds, n_mats: 3121
Elapsed time: 0.0314 seconds, n_mats: 3122
Elapsed time: 0.0324 seconds, n_mats: 3123
Elapsed time: 0.0314 seconds, n_mats: 3124
Elapsed time: 0.0319 seconds, n_mats: 3125
Elapsed time: 0.0322 seconds, n_mats: 3126
Elapsed time: 0.0315 seconds, n_mats: 3127
Elapsed time: 0.0316 seconds, n_mats: 3128
Elapsed time: 0.0320 seconds, n_mats: 3129

Average: 0.053235 seconds

PyTorch 2.4.0

Elapsed time: 0.0786 seconds, n_mats: 3110
Elapsed time: 0.0464 seconds, n_mats: 3111
Elapsed time: 0.1282 seconds, n_mats: 3112
Elapsed time: 0.3918 seconds, n_mats: 3113
Elapsed time: 0.3824 seconds, n_mats: 3114
Elapsed time: 0.3855 seconds, n_mats: 3115
Elapsed time: 0.3872 seconds, n_mats: 3116
Elapsed time: 0.3898 seconds, n_mats: 3117
Elapsed time: 0.3903 seconds, n_mats: 3118
Elapsed time: 0.4193 seconds, n_mats: 3119
Elapsed time: 0.3904 seconds, n_mats: 3120
Elapsed time: 0.3947 seconds, n_mats: 3121
Elapsed time: 0.4009 seconds, n_mats: 3122
Elapsed time: 0.3982 seconds, n_mats: 3123
Elapsed time: 0.3984 seconds, n_mats: 3124
Elapsed time: 0.4022 seconds, n_mats: 3125
Elapsed time: 0.4002 seconds, n_mats: 3126
Elapsed time: 0.4047 seconds, n_mats: 3127
Elapsed time: 0.4123 seconds, n_mats: 3128
Elapsed time: 0.4036 seconds, n_mats: 3129

Average: 0.350255 seconds
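The averages above include the first (cold) iteration, which dominates the 1.13.0 mean. Recomputing them from the logged values, together with the median (less sensitive to a single slow iteration) and the mean without the first iteration, makes the comparison easier to read. A sketch using only the 20 values from each log:

```python
import statistics

# Elapsed times (seconds) for n_mats = 3110..3129, copied from the logs above
T_1_13 = [0.4544, 0.0353, 0.0314, 0.0317, 0.0322, 0.0317, 0.0318, 0.0321,
          0.0317, 0.0317, 0.0344, 0.0319, 0.0314, 0.0324, 0.0314, 0.0319,
          0.0322, 0.0315, 0.0316, 0.0320]
T_2_4 = [0.0786, 0.0464, 0.1282, 0.3918, 0.3824, 0.3855, 0.3872, 0.3898,
         0.3903, 0.4193, 0.3904, 0.3947, 0.4009, 0.3982, 0.3984, 0.4022,
         0.4002, 0.4047, 0.4123, 0.4036]

for label, t in (("1.13.0", T_1_13), ("2.4.0", T_2_4)):
    print(f"{label}: mean={statistics.mean(t):.6f}s  "
          f"median={statistics.median(t):.5f}s  "
          f"mean without first={statistics.mean(t[1:]):.6f}s")
```

With the first iteration excluded, 1.13.0 averages about 0.032 seconds per iteration versus about 0.36 seconds for 2.4.0, i.e. a steady-state gap of roughly 11x in this example.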

To me, it looks as though PyTorch 2.4.0 hits some resource limit at some point, which then leads to consistently longer runtimes from iteration 3 or 4 onward.
I am running it on an RTX 3090 with 24 GB VRAM.
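One way to sanity-check the resource-limit hypothesis is a back-of-the-envelope bound on how much GPU memory the tensors in the example can occupy. A sketch, assuming float32 and the size range of `torch.randint(2, 300, ...)` from the code above (2 to 299 rows and columns); it ignores caching-allocator overhead and cuBLAS workspaces, and the helper names are hypothetical:

```python
MIN_DIM, MAX_DIM = 2, 299  # torch.randint(2, 300, ...) draws from [2, 300)
BYTES_PER_ELEM = 4         # float32
GIB = 2 ** 30


def worst_case_bytes(n_mats: int) -> int:
    """Upper bound for one iteration's inputs `a` plus outputs `c`.

    Each input is at most MAX_DIM x MAX_DIM, and each output aa @ aa.T
    is rows x rows, so it is also at most MAX_DIM x MAX_DIM.
    """
    per_matrix = MAX_DIM * MAX_DIM * BYTES_PER_ELEM
    return 2 * n_mats * per_matrix


def expected_bytes(n_mats: int) -> float:
    """Expected footprint when rows and cols are uniform over [MIN_DIM, MAX_DIM]."""
    dims = range(MIN_DIM, MAX_DIM + 1)
    mean_dim = sum(dims) / len(dims)
    mean_sq_dim = sum(d * d for d in dims) / len(dims)
    per_input = mean_dim * mean_dim * BYTES_PER_ELEM  # rows and cols drawn independently
    per_output = mean_sq_dim * BYTES_PER_ELEM         # output is rows x rows
    return n_mats * (per_input + per_output)


n = 3129  # largest n_mats in the 20-iteration run
print(f"worst case: {worst_case_bytes(n) / GIB:.2f} GiB, "
      f"expected: {expected_bytes(n) / GIB:.2f} GiB")
```

The worst case is about 2.1 GiB and the expected footprint about 0.6 GiB. Even allowing for the previous iteration's `c` still being referenced while the new lists are built, this stays far below the 24 GB of an RTX 3090 (or the 16 GB of a T4), so plain memory capacity seems an unlikely limit; other kinds of limits are not covered by this estimate.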

@paulge: Could you clarify how you got these numbers? Did you run the code on CUDA? The numbers look as though you ran it on the CPU.

Is there any additional information I should provide to make the issue clearer? Or does anyone know of an easy fix? Can anyone reproduce these results on CUDA?

For comparison, I ran the same code with 20 iterations (interval 3110 to 3130) on a Google Colab instance with a T4 GPU.

PyTorch 2.5.1 (I couldn’t figure out an easy way to change the version)

Elapsed time: 0.3557 seconds, n_mats: 3110
Elapsed time: 0.1632 seconds, n_mats: 3111
Elapsed time: 0.3568 seconds, n_mats: 3112
Elapsed time: 1.1286 seconds, n_mats: 3113
Elapsed time: 1.1420 seconds, n_mats: 3114
Elapsed time: 1.1086 seconds, n_mats: 3115
Elapsed time: 1.1074 seconds, n_mats: 3116
Elapsed time: 0.9427 seconds, n_mats: 3117
Elapsed time: 0.9220 seconds, n_mats: 3118
Elapsed time: 0.8918 seconds, n_mats: 3119
Elapsed time: 1.0904 seconds, n_mats: 3120
Elapsed time: 0.9226 seconds, n_mats: 3121
Elapsed time: 0.9580 seconds, n_mats: 3122
Elapsed time: 1.1605 seconds, n_mats: 3123
Elapsed time: 1.0975 seconds, n_mats: 3124
Elapsed time: 1.2236 seconds, n_mats: 3125
Elapsed time: 1.2833 seconds, n_mats: 3126
Elapsed time: 1.1478 seconds, n_mats: 3127
Elapsed time: 1.1807 seconds, n_mats: 3128
Elapsed time: 1.1706 seconds, n_mats: 3129

Comparing the first four iterations of PyTorch 2.4.0 on the RTX 3090 with PyTorch 2.5.1 on Colab’s T4, the runtime changes follow a very similar pattern. From this, I would assume that it is a general issue.
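Because the T4 and the RTX 3090 differ in absolute speed, one way to compare the shapes of the two curves is to normalize each series by its second iteration (n_mats = 3111), which is the fastest in both logs. A sketch using the first four values from the 20-iteration RTX 3090 log and the T4 log (`normalize` is a hypothetical helper):

```python
from typing import List

# First four elapsed times (seconds), n_mats = 3110..3113, from the logs above
RTX3090_TORCH_2_4 = [0.0786, 0.0464, 0.1282, 0.3918]
T4_TORCH_2_5 = [0.3557, 0.1632, 0.3568, 1.1286]


def normalize(times: List[float], ref_index: int = 1) -> List[float]:
    """Express every time as a multiple of times[ref_index]."""
    ref = times[ref_index]
    return [t / ref for t in times]


for label, times in (("RTX 3090, 2.4.0", RTX3090_TORCH_2_4), ("T4, 2.5.1", T4_TORCH_2_5)):
    print(label, [f"{x:.1f}x" for x in normalize(times)])
```

By the fourth iteration both series sit at a similar multiple of their fastest time (about 8.4x on the RTX 3090 and 6.9x on the T4).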

I cannot reproduce any issues on my system, and I also see that the same kernels are called in 1.13.0 and 2.5.1:

2.5.1:
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     81.0      132,596,928     19,811   6,693.1   6,400.0     4,064    13,345      1,634.4  ampere_sgemm_32x32_sliced1x4_tn                                                                     
     10.2       16,745,203      3,491   4,796.7   4,352.0     2,976     8,416      1,404.4  ampere_sgemm_32x128_tn                                                                              
      5.2        8,475,944      4,618   1,835.4   1,792.0     1,567     3,360        189.7  void cublasLt::splitKreduce_kernel<(int)32, (int)16, int, float, float, float, (bool)0, float, floa…
      0.7        1,117,039        330   3,385.0   3,392.0     2,272     4,736        558.1  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)8, (int)9, (bool)0, cublasGemv…
      0.6        1,038,948        321   3,236.6   3,168.0     2,112     4,832        559.8  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)6, (int)7, (bool)0, cublasGemv…
      0.6          927,723        230   4,033.6   4,080.0     3,008     4,960        516.6  ampere_sgemm_128x32_tn                                                                              
      0.5          742,245        120   6,185.4   6,288.0     5,312     6,752        391.8  ampere_sgemm_64x32_sliced1x4_tn                                                                     
      0.4          634,146        154   4,117.8   4,016.0     2,368     5,760        808.0  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)14, (int)15, (bool)0, cublasGe…
      0.4          625,412        175   3,573.8   3,583.0     2,272     4,736        622.3  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)10, (int)11, (bool)0, cublasGe…
      0.2          255,104         95   2,685.3   2,752.0     1,984     3,264        364.0  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)4, (int)4, (bool)1, cublasGemv…
      0.1          241,315         81   2,979.2   2,976.0     2,080     3,712        500.3  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)4, (int)4, (bool)0, cublasGemv…
      0.1          204,262         82   2,491.0   2,528.0     1,824     3,073        287.4  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)2, (int)2, (bool)1, cublasGemv…

1.13.0:
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name                                                
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     81.0      132,172,261     19,797   6,676.4   6,400.0     4,128    13,280      1,619.8  ampere_sgemm_32x32_sliced1x4_tn                                                                     
     10.3       16,814,960      3,505   4,797.4   4,352.0     3,040     8,448      1,408.2  ampere_sgemm_32x128_tn                                                                              
      5.2        8,432,384      4,617   1,826.4   1,791.0     1,599     3,008        184.7  void splitKreduce_kernel<(int)32, (int)16, int, float, float, float, float, (bool)1, (bool)0, (bool…
      0.7        1,109,697        330   3,362.7   3,360.0     2,208     4,608        523.4  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)8, (int)9, (bool)0, cublasGemv…
      0.6        1,033,066        321   3,218.3   3,200.0     2,112     4,448        538.3  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)6, (int)7, (bool)0, cublasGemv…
      0.6          924,904        230   4,021.3   4,032.0     3,040     5,152        493.2  ampere_sgemm_128x32_tn                                                                              
      0.5          738,529        120   6,154.4   6,144.0     5,216     6,752        345.7  ampere_sgemm_64x32_sliced1x4_tn                                                                     
      0.4          628,675        154   4,082.3   3,984.0     2,432     5,792        783.9  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)14, (int)15, (bool)0, cublasGe…
      0.4          627,811        175   3,587.5   3,616.0     2,304     4,800        597.7  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)10, (int)11, (bool)0, cublasGe…
      0.2          256,485         95   2,699.8   2,720.0     2,048     3,424        368.7  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)4, (int)4, (bool)1, cublasGemv…
      0.1          244,066         81   3,013.2   3,008.0     2,176     3,712        408.2  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)4, (int)4, (bool)0, cublasGemv…
      0.1          207,330         82   2,528.4   2,560.0     1,920     3,232        299.9  void gemmSN_TN_kernel<float, (int)128, (int)16, (int)2, (int)4, (int)2, (int)2, (bool)1, cublasGemv…
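To put a number on "the same kernels", the Total Time column of the rows listed in the two tables can be summed and compared. A sketch using only those listed rows:

```python
# "Total Time (ns)" column of the rows listed above, in table order
TOTAL_NS_2_5_1 = [132_596_928, 16_745_203, 8_475_944, 1_117_039, 1_038_948, 927_723,
                  742_245, 634_146, 625_412, 255_104, 241_315, 204_262]
TOTAL_NS_1_13_0 = [132_172_261, 16_814_960, 8_432_384, 1_109_697, 1_033_066, 924_904,
                   738_529, 628_675, 627_811, 256_485, 244_066, 207_330]

total_new = sum(TOTAL_NS_2_5_1)
total_old = sum(TOTAL_NS_1_13_0)
rel_diff = (total_new - total_old) / total_old  # relative to 1.13.0

print(f"2.5.1: {total_new / 1e6:.1f} ms, 1.13.0: {total_old / 1e6:.1f} ms, "
      f"difference: {rel_diff:+.2%}")
```

The listed kernels add up to about 163.6 ms versus 163.2 ms of GPU time, a difference of roughly 0.25%, so on that system the GPU-side work is essentially the same in both versions.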

Based on your description, it looks as if the GPU is throttling, so you might want to lock its clocks for proper profiling.
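A quick way to check for throttling is to sample the current and maximum SM clocks, the temperature, and the power draw while the benchmark runs, for example via `nvidia-smi --query-gpu`. A sketch, assuming `nvidia-smi` is on the PATH; the helper names are hypothetical. Locking clocks is typically done with `nvidia-smi --lock-gpu-clocks=<min>,<max>` (usually requires administrator rights) and undone with `nvidia-smi --reset-gpu-clocks`.

```python
import subprocess
from typing import Dict, List

# Query fields as listed by `nvidia-smi --help-query-gpu`
FIELDS = ["index", "clocks.sm", "clocks.max.sm", "temperature.gpu", "power.draw"]


def parse_gpu_query(csv_text: str) -> List[Dict[str, str]]:
    """Parse `--format=csv,noheader,nounits` output into one dict per GPU."""
    rows = []
    for line in csv_text.strip().splitlines():
        values = [v.strip() for v in line.split(",")]
        rows.append(dict(zip(FIELDS, values)))
    return rows


def query_gpu_clocks() -> List[Dict[str, str]]:
    """Run nvidia-smi once and return the current readings for every GPU."""
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={','.join(FIELDS)}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_query(out)
```

Printing `query_gpu_clocks()` after each iteration of the benchmark loop would show whether `clocks.sm` drops well below `clocks.max.sm` during the slow iterations.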

Could you post how you profiled this? I will try to get the same kind of results.
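The kernel tables above appear to come from an Nsight Systems kernel summary, but the exact commands were not posted. As an alternative that stays inside Python, `torch.profiler` can produce a similar per-operator and per-kernel summary. A sketch, assuming a CUDA-enabled PyTorch build; `profile_matmul_batch` is a hypothetical helper, and its matrix shapes are drawn differently from the original loop:

```python
def profile_matmul_batch(n_mats: int = 3113, device: str = "cuda:0", row_limit: int = 15) -> str:
    """Profile one pass of the per-matrix matmul loop and return a summary table."""
    import torch  # imported lazily; requires a CUDA-enabled PyTorch build
    from torch.profiler import ProfilerActivity, profile

    torch.manual_seed(0)
    sizes = torch.randint(2, 300, (n_mats, 2)).tolist()  # [rows, cols] per matrix
    a = [torch.randn(r, c, device=device) for r, c in sizes]
    _ = [aa @ aa.T for aa in a]  # warm-up outside the profiled region
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        _ = [aa @ aa.T for aa in a]
        torch.cuda.synchronize()  # include all queued kernels in the trace
    return prof.key_averages().table(sort_by="cuda_time_total", row_limit=row_limit)
```

Running `print(profile_matmul_batch())` under both versions gives tables that also include CPU time per operator, which helps tell CPU-side overhead apart from slower kernels.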