I had the following code snippet in my project and noticed a substantial difference in both speed and memory when I switched between einsum and matmul:
import torch
import time

bs = 8
L = 2048
dim = 64

tensor1 = torch.randn((bs, L, dim)).to('cuda')
tensor2 = torch.randn((L, L, dim)).to('cuda')

# warm up the GPU
for _ in range(5):
    warmup_tensor = torch.matmul(tensor1, tensor1.transpose(1, 2))
torch.cuda.synchronize()

start = time.time()
output1 = torch.einsum("bld,lrd->blr", tensor1, tensor2)
torch.cuda.synchronize()
end = time.time()
print('einsum time:', end - start)
print('einsum memory (GB):', torch.cuda.max_memory_allocated('cuda') / 10**9)

torch.cuda.synchronize()
start = time.time()
output2 = torch.matmul(tensor2, tensor1.unsqueeze(-1)).squeeze(-1)
torch.cuda.synchronize()
end = time.time()
print('matmul time:', end - start)
print('matmul memory (GB):', torch.cuda.max_memory_allocated('cuda') / 10**9)

print('same res?', torch.allclose(output1, output2, atol=1e-5))  # we are using float, not double
Running the above code gives the following on a NVIDIA A6000 GPU:
Is this normal? I would expect matmul to be at least as fast and memory efficient as einsum. If that’s not the case, is there any way to know what happened under the hood of einsum?
I don’t know if it’s “normal,” but this kind of thing has been seen before.
See, for example:
It might be worth noting that, because you are adding a trailing singleton
dimension (unsqueeze(-1)) to tensor1, you are, in effect, performing
a batch of vector dot products rather than a batch of fully general matrix
products.
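To make this concrete, here is a small sketch (with made-up tiny shapes; the variable names are mine) showing that each element of your matmul-with-unsqueeze result is just a dot product of a row of tensor2 with a row of tensor1:

import torch

# tiny stand-in shapes, just to illustrate the structure
bs, L, dim = 2, 3, 4
t1 = torch.randn(bs, L, dim)
t2 = torch.randn(L, L, dim)

# matmul with the trailing singleton dimension, as in your code
out = torch.matmul(t2, t1.unsqueeze(-1)).squeeze(-1)   # shape (bs, L, L)

# each output element is a dot product of two length-dim vectors
b, l, r = 1, 2, 0
print(torch.allclose(out[b, l, r], t1[b, l] @ t2[l, r]))   # True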
Computing a batch of dot products is not a rare use case, but pytorch
does not offer a specialized batch-dot-product function. I’ve come to the
conclusion that einsum() is a perfectly satisfactory way to compute a
batch dot product (and it’s what I use by default when the need arises).
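For reference, this is the kind of batch-dot-product idiom I have in mind, together with a couple of equivalent formulations (the tensors here are just illustrative stand-ins):

import torch

a = torch.randn(8, 2048, 64)
b = torch.randn(8, 2048, 64)

# batch of dot products over the last dimension
dots_einsum = torch.einsum('bld,bld->bl', a, b)   # shape (8, 2048)

# equivalent formulations
dots_mul = (a * b).sum(dim=-1)
dots_matmul = torch.matmul(a.unsqueeze(-2), b.unsqueeze(-1)).squeeze(-1).squeeze(-1)

print(torch.allclose(dots_einsum, dots_mul, atol=1e-5))
print(torch.allclose(dots_einsum, dots_matmul, atol=1e-5))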
(It’s worth noting that there are instances where einsum() – perhaps with
older versions of pytorch – unreasonably underperforms the equivalent matmul() computation (with various transpose()s and unsqueeze()s
to get the dimensions to line up correctly).)
Idle speculation:
Perhaps matmul()'s performance tuning has been focused on full matrix
products, rather than the “edge” case of batch dot products. This would
hardly excuse matmul()'s underperformance, but might offer a historical
explanation.
Or it might be some glitch in matmul()'s broadcasting support. It might
be interesting to perform the comparison when creating tensor1 with
an explicit trailing singleton dimension, rather than using unsqueeze().
(You could also try adding a leading singleton dimension to tensor2.
You would, of course, still be broadcasting bs over tensor2’s singleton
dimension and I don’t think it would be a fair comparison to avoid such
broadcasting.)
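As a sketch of what I mean (shapes taken from your example, but the details are my own guess at the experiment), you could build the tensors with the singleton dimensions baked in and then time matmul() exactly the way you did above:

import torch

bs, L, dim = 8, 2048, 64

# (1) tensor1 created with an explicit trailing singleton dimension
tensor1 = torch.randn(bs, L, dim, 1, device='cuda')
tensor2 = torch.randn(L, L, dim, device='cuda')
out_a = torch.matmul(tensor2, tensor1).squeeze(-1)                 # (bs, L, L)

# (2) tensor2 also given an explicit leading singleton batch dimension
out_b = torch.matmul(tensor2.unsqueeze(0), tensor1).squeeze(-1)    # (bs, L, L)

print(torch.allclose(out_a, out_b))

If either variant closes the gap with einsum(), that would point at matmul()'s handling of the implicit broadcasting rather than the underlying kernel.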