```
a = [torch.randn(torch.randint(2, 300, (2,)).tolist()).cuda() for _ in range(20000)]
b = [x.T.clone() for x in a]
torch.cuda.synchronize()
```

This builds a list of tensors of varying shapes (the shape is irrelevant; the problem persists even for tensors of the same shape). Then perform the matrix multiplications:

```
c = torch.cat([(a[i] @ b[i]).flatten() for i in range(len(a))])
torch.cuda.synchronize()
```

Performing the matrix multiplication as above is roughly 10x slower on any PyTorch 2.x version compared with any 1.x version.

The average run time on the same hardware is 0.58 s with a 1.x version, whereas with a 2.x version it is 5.56 s.
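For reference, here is a minimal timing harness for this pattern (a sketch, not the original benchmark script; it falls back to CPU when CUDA is unavailable and uses a smaller list than the 20000-element one in the post):

```python
import time

import torch


def bench_matmul_cat(n_pairs=200, device=None):
    """Time the per-pair matmul + cat pattern, synchronizing around the timed region."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Random matrices of varying shape, and their (contiguous) transposes.
    a = [torch.randn(torch.randint(2, 300, (2,)).tolist(), device=device)
         for _ in range(n_pairs)]
    b = [x.T.clone() for x in a]
    if device == "cuda":
        torch.cuda.synchronize()  # exclude setup from the timed region
    t0 = time.perf_counter()
    c = torch.cat([(a[i] @ b[i]).flatten() for i in range(n_pairs)])
    if device == "cuda":
        torch.cuda.synchronize()  # wait for all kernels before stopping the clock
    return time.perf_counter() - t0, c


elapsed, c = bench_matmul_cat()
print(f"elapsed: {elapsed:.3f}s, result numel: {c.numel()}")
```

The synchronize calls matter: CUDA kernel launches are asynchronous, so without them the timer would only measure launch overhead rather than the actual GPU work.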

I did try setting:

```
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```

None of these settings provides any improvement.

Here is the profiler output for a smaller example, where we observe a CPU time of about 1 s compared with 0.2 s on a 1.x version, though the CUDA times are almost the same. All tests were done on a Tesla V100. Additionally, the memory usage is much higher for repeated runs.

PyTorch 1.x:

```
Name              Self CPU %  Self CPU   CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls
model_inference   30.31%      214.848ms  99.98%       708.780ms  708.780ms     0.000us    0.00%        154.325ms   154.325ms      1
aten::matmul      4.01%       28.425ms   56.60%       401.255ms  20.063us      0.000us    0.00%        148.566ms   7.428us        20000
aten::mm          34.94%      247.695ms  52.63%       373.070ms  18.654us      117.146ms  95.40%       148.682ms   7.434us        20000
cudaLaunchKernel  14.71%      104.315ms  14.71%       104.315ms  5.175us       11.601ms   9.45%        11.692ms    0.580us        20157
```

PyTorch 2.x:

```
Name              Self CPU %  Self CPU   CPU total %  CPU total  CPU time avg  Self CUDA  Self CUDA %  CUDA total  CUDA time avg  # of Calls
model_inference   11.89%      183.206ms  100.00%      1.540s     1.540s        0.000us    0.00%        117.757ms   117.757ms      1
aten::matmul      2.46%       37.840ms   81.24%       1.251s     62.571us      0.000us    0.00%        116.147ms   5.807us        20000
aten::mm          67.93%      1.046s     78.79%       1.214s     60.679us      116.147ms  7.20%        116.147ms   5.807us        20000
cudaLaunchKernel  9.00%       138.639ms  9.00%        138.639ms  6.878us       0.000us    0.00%        0.000us     0.000us       20157
```
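Tables like the ones above can be produced with `torch.profiler`; a minimal sketch (sizes reduced, CPU fallback added, and the `model_inference` label chosen to match the row name in the tables) for anyone wanting to reproduce the measurement:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

device = "cuda" if torch.cuda.is_available() else "cpu"
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

# Same workload as the repro above, with a smaller list.
a = [torch.randn(torch.randint(2, 300, (2,)).tolist(), device=device)
     for _ in range(100)]
b = [x.T.clone() for x in a]

with profile(activities=activities) as prof:
    with record_function("model_inference"):
        c = torch.cat([(a[i] @ b[i]).flatten() for i in range(len(a))])

# Print a summary table sorted by total CPU time, like the tables above.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```

Running this under both a 1.x and a 2.x install on the same machine should make the `aten::mm` self-CPU-time gap directly comparable.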