I recently switched to PyTorch 2.4.0 after having used PyTorch 1.13.0 for quite some time. Since PyTorch 2.0.0 came with many new features, I was interested in trying them out and, of course, also hoped for some reductions in execution time.
However, after testing I found that for my specific use case PyTorch 2.4.0 is up to 50% slower than PyTorch 1.13.0 when running the same code.
I tried to come up with a minimal code example that reproduces part of that issue. Have a look at the following:
import time
import torch
torch.manual_seed(0)
device = 'cuda:0'
start = 3110
end = 3115
step = 1
for n_mats in range(start, end, step):
    a = [torch.randn(*torch.randint(2, 300, (2,))).to(device).to(torch.float32) for _ in range(n_mats)]
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    c = [(aa @ aa.T) for aa in a]
    torch.cuda.synchronize()
    print(f"Elapsed time: {time.perf_counter() - start_time:.4f} seconds, n_mats: {n_mats}")
The results I get are the following:
PyTorch version 1.13.0
Elapsed time: 0.4380 seconds, n_mats: 3110
Elapsed time: 0.0331 seconds, n_mats: 3111
Elapsed time: 0.0288 seconds, n_mats: 3112
Elapsed time: 0.0295 seconds, n_mats: 3113
Elapsed time: 0.0301 seconds, n_mats: 3114
PyTorch version 2.4.0
Elapsed time: 0.0739 seconds, n_mats: 3110
Elapsed time: 0.0439 seconds, n_mats: 3111
Elapsed time: 0.1221 seconds, n_mats: 3112
Elapsed time: 0.3828 seconds, n_mats: 3113
Elapsed time: 0.3765 seconds, n_mats: 3114
Please let me know if I should measure the runtime differently, in case this approach is not representative.
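For reference, one alternative I could imagine is to add warm-up runs and report the median over several repeats instead of a single timing. The following is only a rough sketch of that idea, not an established recipe: the helper name `median_time` and its parameters `sync`, `warmup`, and `repeats` are made up for illustration, and the demo at the bottom runs on the CPU so that it works without a GPU.

```python
import statistics
import time


def median_time(fn, sync=lambda: None, warmup=3, repeats=10):
    """Return the median wall-clock time of fn() in seconds.

    `sync` is called right before the clock starts and right after fn()
    returns. For CUDA work one would pass torch.cuda.synchronize here so
    that queued kernels are included in the measurement (assumption: the
    default no-op is only appropriate for synchronous CPU work).
    """
    # Warm-up runs absorb one-time costs such as lazy initialization.
    for _ in range(warmup):
        fn()

    samples = []
    for _ in range(repeats):
        sync()
        t0 = time.perf_counter()
        fn()
        sync()
        samples.append(time.perf_counter() - t0)

    # The median is less sensitive to occasional outliers than the mean.
    return statistics.median(samples)


# CPU-only demo workload (hypothetical, stands in for the matmul loop).
elapsed = median_time(lambda: sum(i * i for i in range(10_000)))
print(f"Median elapsed time: {elapsed:.6f} seconds")
```

Applied to the example above, the call would presumably look like `median_time(lambda: [aa @ aa.T for aa in a], sync=torch.cuda.synchronize)`, executed once per value of `n_mats`.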
Comparing the two results, both versions seem to have some additional overhead in the first iteration. For version 1.13.0 this overhead is drastic (0.4380 seconds versus roughly 0.03 seconds afterwards), whereas for version 2.4.0 it seems negligible by comparison. However, version 1.13.0 runs quite fast after the first iteration and seems to maintain a runtime that grows proportionally with the number of matrix multiplications. In contrast, version 2.4.0 slows down by roughly 3x in iteration 3 and by another 3x in iteration 4.
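To make the size of these jumps explicit, here is a small sketch that computes the ratios directly from the timings printed above. The numbers are copied from my two runs; the variable names are just for illustration.

```python
# Timings in seconds, copied from the two runs above.
n_mats = [3110, 3111, 3112, 3113, 3114]
times_113 = [0.4380, 0.0331, 0.0288, 0.0295, 0.0301]  # PyTorch 1.13.0
times_240 = [0.0739, 0.0439, 0.1221, 0.3828, 0.3765]  # PyTorch 2.4.0

# Slowdown of each 2.4.0 iteration relative to the previous one.
step_ratios = [later / earlier for earlier, later in zip(times_240, times_240[1:])]

# Slowdown of 2.4.0 relative to 1.13.0 for the same n_mats.
version_ratios = [new / old for old, new in zip(times_113, times_240)]

for n, ratio in zip(n_mats[1:], step_ratios):
    print(f"n_mats {n}: {ratio:.2f}x the previous 2.4.0 iteration")

for n, ratio in zip(n_mats, version_ratios):
    print(f"n_mats {n}: 2.4.0 takes {ratio:.2f}x the time of 1.13.0")
```

In this minimal example the gap in the last two iterations is more than 10x, which is much larger than the up to 50% I see in my actual code.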
Can anyone explain what is causing these runtimes?