I tested it like this:
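```python
import time

import torch

n_times = 10000

# Batched double-precision operands: (2000, 3, 4) @ (2000, 4, 5) @ (2000, 5, 6)
a = torch.rand(2000, 3, 4).double()
b = torch.rand(2000, 4, 5).double()
c = torch.rand(2000, 5, 6).double()


def sync():
    # torch.cuda.synchronize() waits for queued GPU work; it raises on
    # CPU-only builds, so only call it when CUDA is available
    if torch.cuda.is_available():
        torch.cuda.synchronize()


t1 = time.perf_counter()
for _ in range(n_times):
    r1 = (a @ b) @ c  # contract a with b first
sync()
t2 = time.perf_counter()
for _ in range(n_times):
    r2 = a @ (b @ c)  # contract b with c first
sync()
t3 = time.perf_counter()
for _ in range(n_times):
    r3 = torch.einsum('abc,acd,ade->abe', a, b, c)  # single einsum over all three operands
sync()
t4 = time.perf_counter()

# All three results should agree up to floating-point rounding
print((r1 - r2).abs().max())
print((r2 - r3).abs().max())
# Elapsed seconds for each variant
print(t2 - t1)
print(t3 - t2)
print(t4 - t3)
```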
The three timings are 6 s for `(a @ b) @ c`, 1 s for `a @ (b @ c)`, and 2 s for `torch.einsum`.
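For reference, the multiply-accumulate (MAC) count of each ordering can be worked out from the shapes alone. The sketch below is plain Python; `bmm_macs` is a hypothetical helper that just multiplies the dimensions of a batched `(m x k) @ (k x n)` product. By this count the left-first order `(a @ b) @ c` is actually the cheaper one, which suggests the 6 s vs 1 s gap is not explained by arithmetic alone and may instead come from per-call overhead or kernel choice for such small matrices (an assumption, not something measured here).

```python
def bmm_macs(batch, m, k, n):
    """Multiply-accumulate count of a batched (m x k) @ (k x n) product (hypothetical helper)."""
    return batch * m * k * n


BATCH = 2000
# Shapes from the benchmark: a is (B, 3, 4), b is (B, 4, 5), c is (B, 5, 6)
left_first = bmm_macs(BATCH, 3, 4, 5) + bmm_macs(BATCH, 3, 5, 6)   # (a @ b) @ c
right_first = bmm_macs(BATCH, 4, 5, 6) + bmm_macs(BATCH, 3, 4, 6)  # a @ (b @ c)

print(f"(a @ b) @ c: {left_first} MACs")   # 300000
print(f"a @ (b @ c): {right_first} MACs")  # 384000
```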
It seems that `torch.einsum` does not pick the fastest contraction order for the given inputs. Is this expected behavior?