einsum running 100x slower than expected when controlling for number of ops

For two simple einsum cases, Case 2 is running ~30% slower than Case 1, despite needing to perform 101x fewer ops.

Case 1

import torch

a = torch.randn(10000, 100, 101).cuda()
b = torch.randn(10000, 101, 3).cuda()

%%time
for _ in range(10000):
    torch.einsum("bij, bjf->bif", a, b)
    torch.cuda.synchronize()

CPU times: user 26.7 s, sys: 18.8 s, total: 45.5 s
Wall time: 45.5 s

Case 2

c = torch.randn(10000, 100, 1).cuda()
d = torch.randn(10000, 100, 1, 3).cuda()

%%time
for _ in range(10000):
    torch.einsum("bic, bicf->bif", c, d)
    torch.cuda.synchronize()

CPU times: user 37.7 s, sys: 23.7 s, total: 1min 1s
Wall time: 1min 1s

I believe the number of ops needed for each case should be approximately:

Case 1: 10,000 × 100 × 101 × 3 (× 2)
Case 2: 10,000 × 100 × 1 × 3 (× 2)

(The factor of 2 accounts for the multiply-accumulate operation used in matrix multiplication.)
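
Written out in code as a quick check (this just restates the counts above):

flops_case1 = 2 * 10_000 * 100 * 101 * 3   # ~6.06e8 ops per einsum call
flops_case2 = 2 * 10_000 * 100 * 1 * 3     # ~6.0e6 ops per einsum call
print(flops_case1 / flops_case2)           # 101.0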

Despite requiring 101x fewer ops and operating over smaller tensors, in the example above Case 2 has a wall clock time 34% greater than Case 1.
Are there methods/kernels that are better optimised for the latter case?
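
For comparison, since the contracted index c has size 1 here, the same result can also be computed with a plain broadcasted multiply and sum, avoiding the einsum call entirely; a minimal sketch using the tensors above:

out = (c.unsqueeze(-1) * d).sum(dim=2)                       # shape (10000, 100, 3)
torch.allclose(out, torch.einsum("bic, bicf->bif", c, d))    # True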

The problem is that einsum reduces to batch matmul and so copies your data around.
The real solution is to implement a more general contraction. I had a patch using TensorIterators instead, a few years ago, but somehow I decided that it would not work on CPU and abandoned it instead of measuring it on GPU. And then it was never enough fun to dust it off.
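
To make the batch-matmul reduction concrete, the "bic, bicf->bif" case ends up as roughly the following (a sketch of the equivalent formulation, not the exact code einsum dispatches to):

c2 = c.reshape(10000 * 100, 1, 1)                 # fold the b and i dims into one batch dim
d2 = d.reshape(10000 * 100, 1, 3)
out = torch.bmm(c2, d2).reshape(10000, 100, 3)    # same result as einsum("bic, bicf->bif", c, d)

A batch of a million 1×1 @ 1×3 matmuls is dominated by overhead rather than arithmetic, and in the general case the permutes needed to reach this layout also copy the data.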

One easy way would be to try the KeOps third-party library.

Best regards

Thomas


Thanks for your help with this Thomas!

I’ll play around with KeOps and see if I can get something closer to the desired behaviour.

Best wishes,

Sam