Einsum is not fast?

I tested it like this:

import time

import torch

n_times = 10000
# tensors live on the GPU, so kernel launches are asynchronous
a = torch.rand(2000, 3, 4).double().cuda()
b = torch.rand(2000, 4, 5).double().cuda()
c = torch.rand(2000, 5, 6).double().cuda()

t1 = time.time()
for _ in range(n_times):
    r1 = (a @ b) @ c  # left-to-right association
torch.cuda.synchronize()

t2 = time.time()
for _ in range(n_times):
    r2 = a @ (b @ c)  # right-to-left association
torch.cuda.synchronize()

t3 = time.time()
for _ in range(n_times):
    r3 = torch.einsum('abc,acd,ade->abe', a, b, c)  # same contraction via einsum
torch.cuda.synchronize()

t4 = time.time()

# results should agree up to floating-point error
print((r1 - r2).abs().max())
print((r2 - r3).abs().max())

# elapsed time for each variant
print(t2 - t1)
print(t3 - t2)
print(t4 - t3)

The times for the three variants are about 6 s, 1 s, and 2 s.

It seems that torch.einsum does not choose the best contraction order for the given inputs. Is this the expected behavior?
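
For reference, NumPy's np.einsum_path shows which contraction order an optimizing einsum would pick for these shapes. This is just an illustration of path optimization in general; whether torch.einsum applies such an optimization may depend on the version:

import numpy as np

a = np.random.rand(2000, 3, 4)
b = np.random.rand(2000, 4, 5)
c = np.random.rand(2000, 5, 6)

# 'optimal' searches over contraction orders and reports the chosen path
path, info = np.einsum_path('abc,acd,ade->abe', a, b, c, optimize='optimal')
print(path)  # e.g. ['einsum_path', (0, 1), (0, 1)]
print(info)  # FLOP comparison of naive vs. optimized contraction order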

Note that the first timing is wrong: there is no torch.cuda.synchronize() before t1, so the asynchronous tensor creation is also accumulated into the first measurement.
I get quite similar times for all three workloads, although einsum might indeed add a small overhead:

0.47559189796447754
0.4516446590423584
0.4858055114746094
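
For a fairer comparison, torch.utils.benchmark.Timer handles CUDA synchronization around each measurement for you. A minimal sketch, where the shapes and dtype mirror the post and the iteration count is arbitrary:

import torch
from torch.utils.benchmark import Timer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
a = torch.rand(2000, 3, 4, dtype=torch.double, device=device)
b = torch.rand(2000, 4, 5, dtype=torch.double, device=device)
c = torch.rand(2000, 5, 6, dtype=torch.double, device=device)

stmts = [
    '(a @ b) @ c',
    'a @ (b @ c)',
    "torch.einsum('abc,acd,ade->abe', a, b, c)",
]
for stmt in stmts:
    timer = Timer(stmt=stmt, globals={'torch': torch, 'a': a, 'b': b, 'c': c})
    print(timer.timeit(1000))  # synchronizes CUDA, so each variant is measured cleanly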