GPU speed and memory difference between einsum and matmul

Hi,

I had the following code snippet in my project, and I noticed a substantial difference in both speed and memory when I switched between einsum and matmul:

import torch
import time

bs = 8
L = 2048
dim = 64

tensor1 = torch.randn((bs, L, dim)).to('cuda')
tensor2 = torch.randn((L, L, dim)).to('cuda')

# warmup the GPU
for _ in range(5):
    warmup_tensor = torch.matmul(tensor1, tensor1.transpose(1, 2))

torch.cuda.synchronize()
start = time.time()
output1 = torch.einsum("bld,lrd->blr", tensor1, tensor2)
torch.cuda.synchronize()
end = time.time()
print('einsum time:', end-start)
print('einsum memory (GB):', torch.cuda.max_memory_allocated('cuda')/10**9)

torch.cuda.synchronize()
start = time.time()
output2 = torch.matmul(tensor2, tensor1.unsqueeze(-1)).squeeze(-1)
torch.cuda.synchronize()
end = time.time()
print('matmul time:', end-start)
print('matmul memory (GB):', torch.cuda.max_memory_allocated('cuda')/10**9)

print('same res?', torch.allclose(output1, output2, atol=1e-5)) # we are using float not double

Running the above code gives the following on an NVIDIA A6000 GPU:

einsum time: 0.0035064220428466797
einsum memory (GB): 1.346371584
matmul time: 0.04485011100769043
matmul memory (GB): 10.070523904
same res? True

Is this normal? I would expect matmul to be as fast and memory-efficient as einsum. If that’s not the case, is there any way to know what happened under the hood of einsum?

Thanks!

Hi Jason!

I can reproduce your observation (on a smaller GPU).

Here is my tweaked version of your test script:

import torch
import time

print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())
print (torch.cuda.get_device_properties ('cuda').total_memory)

_ = torch.manual_seed (2022)

bs = 8
L = 1024   # reduce size to fit in smaller gpu
dim = 64

tensor1 = torch.randn((bs, L, dim)).to('cuda')
tensor2 = torch.randn((L, L, dim)).to('cuda')

# warmup the GPU -- use actual tensors and operations
for _ in range(5):
    warmup_tensor = torch.matmul(tensor1, tensor1.transpose(1, 2))
    warmup_tensor = None
    warmup_tensor = torch.einsum("bld,lrd->blr", tensor1, tensor2)
    warmup_tensor = None
    warmup_tensor = torch.matmul(tensor2, tensor1.unsqueeze(-1)).squeeze(-1)
    warmup_tensor = None

torch.cuda.reset_peak_memory_stats ('cuda')
torch.cuda.synchronize()
start = time.time()
output1 = torch.einsum("bld,lrd->blr", tensor1, tensor2)
torch.cuda.synchronize()
end = time.time()
print('einsum time:', end-start)
print('einsum memory (GB):', torch.cuda.max_memory_allocated('cuda')/10**9)

output1 = None

torch.cuda.reset_peak_memory_stats ('cuda')
torch.cuda.synchronize()
start = time.time()
output2 = torch.matmul(tensor2, tensor1.unsqueeze(-1)).squeeze(-1)
torch.cuda.synchronize()
end = time.time()
print('matmul time:', end-start)
print('matmul memory (GB):', torch.cuda.max_memory_allocated('cuda')/10**9)

output1 = torch.einsum("bld,lrd->blr", tensor1, tensor2)   # recompute einsum result for allclose() check

print('same res?', torch.allclose(output1, output2, atol=1e-5)) # we are using float not double

And here is its output:

1.12.0
11.6
GeForce GTX 1050 Ti
4236312576
einsum time: 0.008707761764526367
einsum memory (GB): 0.337641472
matmul time: 0.07655215263366699
matmul memory (GB): 2.48512512
same res? True

I don’t know if it’s “normal,” but this kind of behavior has been
reported before (see, for example, similar threads on these forums).

It might be worth noting that because you are adding a trailing singleton
dimension (unsqueeze (-1)) to tensor1, you are, in effect, performing
a batch of vector dot products rather than a batch of fully general matrix
products.
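
To make that concrete, here is a small self-contained sketch (with
made-up sizes) checking that the matmul formulation computes
out[b, l, r] = dot (tensor1[b, l, :], tensor2[l, r, :]), i.e., a batch
of vector dot products:

import torch

bs, L, dim = 2, 5, 3
t1 = torch.randn(bs, L, dim)
t2 = torch.randn(L, L, dim)

# matmul formulation: (L, L, dim) @ (bs, L, dim, 1) -> (bs, L, L, 1)
out_mm = torch.matmul(t2, t1.unsqueeze(-1)).squeeze(-1)

# explicit batch of dot products: out[b, l, r] = sum over d of t1[b, l, d] * t2[l, r, d]
out_dot = (t1.unsqueeze(2) * t2.unsqueeze(0)).sum(dim=-1)

# einsum version from the original script
out_es = torch.einsum("bld,lrd->blr", t1, t2)

print(torch.allclose(out_mm, out_dot, atol=1e-5))
print(torch.allclose(out_es, out_dot, atol=1e-5))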

Computing a batch of dot products is not a rare use case, but pytorch
does not offer a specialized batch-dot-product function. I’ve come to
the conclusion that einsum() is a perfectly satisfactory way to compute
a batch dot product (and it’s what I use by default when the need arises).

(It’s worth noting that there are instances where einsum() – perhaps with
older versions of pytorch – unreasonably underperforms the equivalent
matmul() computation (with various transpose()s and unsqueeze()s
to get the dimensions to line up correctly).)
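
As for seeing what happens under the hood, one option is to profile the
CUDA kernels each formulation actually launches. A minimal sketch,
assuming torch.profiler is available and tensor1 / tensor2 are defined
as in the script above:

import torch
from torch.profiler import profile, ProfilerActivity

# run each formulation under the profiler and print the top CUDA kernels it launched
for name, fn in [
    ('einsum', lambda: torch.einsum("bld,lrd->blr", tensor1, tensor2)),
    ('matmul', lambda: torch.matmul(tensor2, tensor1.unsqueeze(-1)).squeeze(-1)),
]:
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        _ = fn()
        torch.cuda.synchronize()
    print(name)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))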

Idle speculation:

Perhaps matmul()'s performance tuning has been focused on full matrix
products, rather than the “edge” case of batch dot products. This would
hardly excuse matmul()'s underperformance, but might offer a historical
explanation.

Or it might be some glitch in matmul()'s broadcasting support. It might
be interesting to perform the comparison when creating tensor1 with
an explicit trailing singleton dimension, rather than using unsqueeze().
(You could also try adding a leading singleton dimension to tensor2.
You would, of course, still be broadcasting bs over tensor2’s singleton
dimension and I don’t think it would be a fair comparison to avoid such
broadcasting.)
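
For what it’s worth, here is a rough sketch of that experiment (meant
to be appended to the tweaked script above; the trailing-singleton
tensor is fresh random data, so only the timing and memory numbers are
meaningful, not an allclose() check):

tensor1_col = torch.randn((bs, L, dim, 1)).to('cuda')   # trailing singleton baked in -- no unsqueeze()
tensor2_bat = tensor2.unsqueeze(0)                       # leading singleton batch dimension on tensor2

torch.cuda.reset_peak_memory_stats ('cuda')
torch.cuda.synchronize()
start = time.time()
output3 = torch.matmul(tensor2_bat, tensor1_col).squeeze(-1)   # bs still broadcasts over tensor2's singleton
torch.cuda.synchronize()
end = time.time()
print('matmul (explicit singletons) time:', end - start)
print('matmul (explicit singletons) memory (GB):', torch.cuda.max_memory_allocated('cuda')/10**9)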

Best.

K. Frank