Is there any method to estimate the time consumption of torch.tensordot() or einsum() on the GPU?

When contracting some large-dimensional tensors, we found that two kinds of kernels are used: unrolled_elementwise_kernel for the tensor reshape, and cutlass::Kernel for the tensor mm.

For the first step (the reshape), how can I calculate the amount of data copied and estimate the cost of the device-to-device copy?

For the second step (the mm), several or many CUDA kernels are always launched for larger tensors. Does anybody know where these come from?

Thanks.

The reshape will copy the tensor, so it is likely you will find an approximately linear relationship between .numel() and runtime. There are some subtleties because the input is going to be non-contiguous (otherwise the .reshape would not copy), but maybe those do not change much between the things you are comparing.
I think mm on CUDA uses cuBLAS by default, so what you are seeing are likely optimizations cuBLAS makes.
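
If you want to check the scaling, a rough sketch like this (with placeholder shapes, untested for your exact case) should already show it:

import time
import torch

# time a non-contiguous reshape (which copies) for a few sizes to check
# that the runtime grows roughly linearly with .numel()
for n in (2**20, 2**22, 2**24):
    x = torch.randn(n // 1024, 1024, device="cuda").t()   # transpose -> non-contiguous
    torch.cuda.synchronize()
    t0 = time.time()
    y = x.reshape(-1)          # has to copy because x is non-contiguous
    torch.cuda.synchronize()
    print(f"numel={n}: {time.time() - t0:.6f} s")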

Best regards

Thomas

Thanks for your reply @tom. I have done more tests recently. For the first (reshape) step, we found that which dimensions are selected for contraction matters a lot.

For example, when I select 6 dimensions out of a 31-dim tensor to permute and reshape, choosing dimensions (22, 24, 26, 18, 21, 0) costs ~48 ms, while (29, 0, 8, 25, 26, 27) costs ~25 ms (my test code is pasted at the end).

Profiling shows that the slower case is memory bound and includes more uncoalesced global memory accesses, but so far I cannot figure out how to calculate the memory requirement.

import torch
import time

# 31-dimensional tensor with size 2 along every axis (2**31 complex64 elements)
a = torch.randn([2]*31, device="cuda", dtype=torch.complex64)

for i in range(20):
    torch.cuda.synchronize()
    t0 = time.time()
    # slow case (~48 ms): contracted dimensions (22, 24, 26, 18, 21, 0) moved to the end
    #a.permute((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 23, 25, 27, 28, 29, 30, 22, 24, 26, 18, 21, 0)).reshape([2**25, 2**6])
    # fast case (~25 ms): contracted dimensions (29, 0, 8, 25, 26, 27) moved to the end
    a.permute((1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 28, 30, 29, 8, 25, 26, 27, 0)).reshape([2**25, 2**6])
    torch.cuda.synchronize()
    t1 = time.time()
    print(f'Iter#{i} time/sec: ', t1 - t0)
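
For reference, the minimum traffic for this step is one read plus one write of every element, which for this tensor is 2 * 2**31 * 8 bytes ≈ 32 GiB per permute+reshape. Dividing by the measured times gives an effective bandwidth to compare against the GPU's peak (a back-of-the-envelope sketch; it does not yet explain the extra uncoalesced traffic the profiler shows in the slow case):

# back-of-the-envelope effective bandwidth for the permute+reshape above,
# assuming the ideal case of one read and one write per complex64 element
# (uncoalesced accesses in the slow case will make the real traffic higher)
elem_bytes = 8                         # complex64 = 2 x float32
numel = 2 ** 31                        # [2]*31 tensor
bytes_moved = 2 * numel * elem_bytes   # read input + write output
for t in (0.048, 0.025):               # the two measured times above, in seconds
    print(f"{t * 1e3:.0f} ms -> {bytes_moved / t / 1e9:.0f} GB/s effective")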

So einsum translates to a series of batched matrix multiplications, after permuting the operands and merging the axes that are treated the same (contracted, lhs-to-result, rhs-to-result, batch). You could do this yourself if you wanted to.
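
For example (an illustrative sketch with made-up labels, not the exact code path einsum takes internally):

import torch

# spell out einsum "xykb,bjk->bxyj" as permute + reshape (merge axes) + bmm
a = torch.randn(2, 3, 5, 4)                 # axes x, y, k, b
b = torch.randn(4, 7, 5)                    # axes b, j, k
ref = torch.einsum("xykb,bjk->bxyj", a, b)

# bring operands to (batch, kept axes, contracted axes) and (batch, contracted, kept)
a2 = a.permute(3, 0, 1, 2).reshape(4, 2 * 3, 5)   # this reshape copies: input is non-contiguous
b2 = b.permute(0, 2, 1)                            # (b, k, j)

out = torch.bmm(a2, b2).reshape(4, 2, 3, 7)        # contract over k, then un-merge x, y
print(torch.allclose(ref, out, atol=1e-5))         # True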