For two simple einsum cases, Case 2 runs ~34% slower than Case 1, despite needing roughly 101x fewer ops.
Case 1
```python
import torch

a = torch.randn(10000, 100, 101).cuda()
b = torch.randn(10000, 101, 3).cuda()
```

```python
%%time
for _ in range(10000):
    torch.einsum("bij, bjf->bif", a, b)
torch.cuda.synchronize()
```

```
CPU times: user 26.7 s, sys: 18.8 s, total: 45.5 s
Wall time: 45.5 s
```
Case 2
```python
c = torch.randn(10000, 100, 1).cuda()
d = torch.randn(10000, 100, 1, 3).cuda()
```

```python
%%time
for _ in range(10000):
    torch.einsum("bic, bicf->bif", c, d)
torch.cuda.synchronize()
```

```
CPU times: user 37.7 s, sys: 23.7 s, total: 1min 1s
Wall time: 1min 1s
```
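(If loop overhead is a concern, a more isolated per-call measurement can be taken with `torch.utils.benchmark`, which handles CUDA synchronisation itself. A minimal sketch, not part of the timings above:)

```python
import torch.utils.benchmark as benchmark

t = benchmark.Timer(
    stmt='torch.einsum("bic, bicf->bif", c, d)',
    globals={"c": c, "d": d},  # torch is available inside Timer by default
)
print(t.timeit(1000))  # mean time per call, with proper CUDA sync
```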
I believe the number of ops needed for each case should be approximately:

- Case 1: 10,000 × 100 × 101 × 3 (× 2)
- Case 2: 10,000 × 100 × 1 × 3 (× 2)

(The factor of 2 accounts for the multiply-accumulate operation used in matrix multiplication.)
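As a quick sanity check on that arithmetic (the variable names here are just illustrative):

```python
# Rough FLOP estimates: each multiply-accumulate counted as 2 ops.
case1_ops = 10_000 * 100 * 101 * 3 * 2  # 606,000,000
case2_ops = 10_000 * 100 * 1 * 3 * 2    # 6,000,000
print(case1_ops / case2_ops)            # 101.0
```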
Despite requiring 101x fewer ops and operating over smaller tensors, Case 2 has a wall-clock time 34% greater than Case 1 in the example above.
Are there methods/kernels that are better optimised for the latter case?
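For example (a sketch of the kind of alternative I have in mind, not a measured result): since the contracted axis in Case 2 has size 1, the einsum reduces to a broadcast multiply, which should avoid dispatching to a batched-matmul kernel at all.

```python
# Case 2 contracts over an axis of size 1, so the einsum is equivalent to
# a broadcast multiply followed by dropping that axis.
out = (c.unsqueeze(-1) * d).squeeze(2)  # shape: (10000, 100, 3)

# Sanity check against the einsum result (illustrative, not from the runs above):
assert torch.allclose(out, torch.einsum("bic, bicf->bif", c, d))
```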