I recently discovered torch.einsum, and it seems it can be useful for parallelizing linear operations on the GPU, since it lets you control exactly which dimensions are summed. Others seem to have had the same idea, as I see it used in popular open-source PyTorch code such as `modeling_bert.py` in the Hugging Face transformers repository on GitHub.
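As a concrete illustration of the "un-summed batch dimension" idea, here is a minimal sketch using NumPy's einsum, which shares torch.einsum's subscript notation. The shapes and the scenario (several independent linear layers applied in one call) are made up for illustration:

```python
import numpy as np

# Hypothetical setup: 8 independent linear maps applied to 8 separate inputs.
x = np.random.rand(8, 32, 16)   # (layers, batch, in_features)
W = np.random.rand(8, 16, 4)    # (layers, in_features, out_features)

# One einsum call contracts only the in_features axis ("i"); the layer
# axis "l" appears on both sides, so it is kept rather than summed, and
# all 8 matmuls run as a single batched operation.
y = np.einsum("lbi,lio->lbo", x, W)

# Equivalent loop over separate matmuls:
y_loop = np.stack([x[l] @ W[l] for l in range(8)])
assert np.allclose(y, y_loop)
```

The same subscripts work verbatim with torch.einsum on CUDA tensors.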
I wrote some thoughts about it here on the PyTorch forums.
But in short, it seems that it can be a powerful tool to parallelize separate linear operations on the GPU. But it seems that this might be too good to be true, as discussed in this closed issue
(opened 24 Jan 2020 at 08:23 PM UTC, closed 8 Jun 2021 at 04:57 PM UTC; labels: module: performance, module: cuda, triaged, module: linear algebra)
## 🐛 Bug
A manual multiplication and summation `(a * b).sum(dim = (-3, -2, -1))… ` is about 20X faster than the equivalent `einsum`.
## To Reproduce
```python
import torch

a = torch.rand(1, 1, 16, 2, 16, 2, 16, 2, 2, 2, 2, device="cuda")
b = torch.rand(729, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, device="cuda")

# Warmup
for i in range(100):
    output1 = (a * b).sum(dim=(-3, -2, -1))

# Method 1: broadcast multiply, then sum over the last three dimensions
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    output1 = (a * b).sum(dim=(-3, -2, -1))
print(prof.key_averages().table(sort_by="cuda_time_total"))

# Warmup
for i in range(100):
    a2, b2 = torch.broadcast_tensors(a, b)
    output2 = torch.einsum("...ijk, ...ijk -> ...", a2, b2)

# Method 2: the equivalent einsum contraction
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    a2, b2 = torch.broadcast_tensors(a, b)
    output2 = torch.einsum("...ijk, ...ijk -> ...", a2, b2)
print(prof.key_averages().table(sort_by="cuda_time_total"))
```
```
-----  ----------------  --------------  -----------  ---------  ------------  ------------  ----------  -------------  ---------------
Name   Self CPU total %  Self CPU total  CPU total %  CPU total  CPU time avg  CUDA total %  CUDA total  CUDA time avg  Number of Calls
-----  ----------------  --------------  -----------  ---------  ------------  ------------  ----------  -------------  ---------------
sum    28.93%            43.848us        28.93%       43.848us   43.848us      65.81%        8.978ms     8.978ms        1
mul    71.07%            107.713us       71.07%       107.713us  107.713us     34.19%        4.665ms     4.665ms        1
-----  ----------------  --------------  -----------  ---------  ------------  ------------  ----------  -------------  ---------------
Self CPU time total: 151.561us
CUDA time total: 13.644ms

-----------------  ----------------  --------------  -----------  ---------  ------------  ------------  ----------  -------------  ---------------
Name               Self CPU total %  Self CPU total  CPU total %  CPU total  CPU time avg  CUDA total %  CUDA total  CUDA time avg  Number of Calls
-----------------  ----------------  --------------  -----------  ---------  ------------  ------------  ----------  -------------  ---------------
einsum             4.05%             59.449us        97.88%       1.435ms    1.435ms       48.62%        123.798ms   123.798ms      1
bmm                85.15%            1.248ms         85.15%       1.248ms    1.248ms       45.83%        116.694ms   116.694ms      1
reshape            1.35%             19.770us        6.36%        93.237us   46.619us      2.76%         7.031ms     3.516ms        2
clone              3.96%             58.021us        3.96%        58.021us   29.011us      2.75%         7.012ms     3.506ms        2
broadcast_tensors  0.79%             11.624us        2.12%        31.095us   31.095us      0.01%         31.008us    31.008us       1
expand             1.33%             19.471us        1.33%        19.471us   9.736us       0.01%         20.480us    10.240us       2
permute            1.54%             22.535us        1.54%        22.535us   4.507us       0.01%         19.327us    3.865us        5
view               0.78%             11.442us        0.78%        11.442us   5.721us       0.00%         4.094us     2.047us        2
_unsafe_view       1.05%             15.446us        1.05%        15.446us   7.723us       0.00%         3.968us     1.984us        2
-----------------  ----------------  --------------  -----------  ---------  ------------  ------------  ----------  -------------  ---------------
Self CPU time total: 1.466ms
CUDA time total: 254.614ms
```
## Expected behavior
`einsum` should be no slower than the manual version. I actually expected, and hoped, it would be faster, since it avoids materializing the large intermediate result of the multiplication.
## Environment
The environment is Google Colab with a fresh GPU instance on which `!pip install -q torch==1.4.0 torchvision==0.5.0` has been run as Colab is still on `1.3.1` by default.
```
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.12.0
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: Tesla P100-PCIE-16GB
Nvidia driver version: 418.67
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip3] numpy==1.17.5
[pip3] torch==1.4.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.5.0
[conda] Could not collect
```
cc @ngimel @vincentqb @vishwakftw @jianyuh @nikitaved @pearu @VitalyFedyunin @mruberry
They fixed the major slowdown issues, but from my limited comprehension of the discussion, it seems some inherent overhead remains when using einsum.
If this is true, is there a scale at which einsum becomes preferable? Is there a better method for parallelizing separate linear operations?
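For reference, the two methods benchmarked in the issue do compute the same values. A small check with NumPy's einsum, which uses the same subscript notation as torch.einsum (toy shapes substituted for the issue's large ones):

```python
import numpy as np

# Small stand-ins for the issue's broadcast-incompatible leading dims.
a = np.random.rand(3, 1, 2, 2, 2)
b = np.random.rand(1, 4, 2, 2, 2)

# Method 1: broadcast multiply, then sum the last three axes.
manual = (a * b).sum(axis=(-3, -2, -1))

# Method 2: einsum contraction over the same three axes.
a2, b2 = np.broadcast_arrays(a, b)
via_einsum = np.einsum("...ijk,...ijk->...", a2, b2)

assert np.allclose(manual, via_einsum)
```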
eqy — May 30, 2022, 9:40pm (#2)
The limitations of einsum are likely due to the limited scope of the underlying kernels and dispatch strategies implemented for it. You might get better results if your computation maps more directly onto something like `torch.bmm` rather than being expressed via an einsum.
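To make the bmm suggestion concrete: the issue's contraction `"...ijk, ...ijk -> ..."` is just a per-batch dot product, so it can be phrased as a batched matmul over flattened views. A sketch with NumPy's batched `@` operator (the torch equivalent would use `torch.bmm` on the same shapes; B and K here are illustrative):

```python
import numpy as np

B, K = 729, 2 * 2 * 2   # batch count, flattened size of the (i, j, k) block
a2 = np.random.rand(B, K)  # each row: one flattened (i, j, k) block
b2 = np.random.rand(B, K)

# A per-batch dot product as a batched matmul: (B, 1, K) @ (B, K, 1)
# gives (B, 1, 1), which maps directly onto bmm-style kernels.
out = (a2[:, None, :] @ b2[:, :, None]).reshape(B)

# Same result as the elementwise multiply-and-sum formulation.
assert np.allclose(out, (a2 * b2).sum(axis=-1))
```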