Parallelizing many separate linear operations: is torch.einsum a viable option?

I recently discovered torch.einsum, and it seems that it can be useful for parallelizing linear operations on the GPU, since it can prevent summation across certain dimensions. It seems others have had the same idea, as I see it in popular open-source PyTorch code such as hf transformers/ at main · huggingface/transformers · GitHub
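To make the idea concrete, here is a minimal sketch of the kind of thing I mean, with hypothetical shapes: n_ops independent linear maps, each with its own weight matrix, fused into a single einsum call instead of a Python loop.

```python
import torch

# Hypothetical setup: n_ops independent linear maps, each with its own
# weight matrix, applied to its own slice along a "which op" dimension.
batch, n_ops, d_in, d_out = 4, 3, 5, 2
x = torch.randn(batch, n_ops, d_in)    # one input vector per op
W = torch.randn(n_ops, d_out, d_in)    # one weight matrix per op

# einsum sums over i (the input-feature dim) but keeps b and n separate,
# so all n_ops linear operations run in one fused call.
y = torch.einsum('bni,noi->bno', x, W)  # shape: (batch, n_ops, d_out)

# Reference: apply the ops one at a time in a loop.
y_loop = torch.stack([x[:, n] @ W[n].T for n in range(n_ops)], dim=1)
assert torch.allclose(y, y_loop, atol=1e-6)
```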

I wrote some thoughts about it here

In short, it seems that it can be a powerful tool for parallelizing separate linear operations on the GPU. But this might be too good to be true, as discussed in this closed issue

They fixed the major slowdown issues, but from my limited comprehension of the discussion, it seems that there are still some inherent slowdowns when using einsum.

If this is true, is there a scale at which einsum becomes desirable? Is there a better method for parallelizing separate linear operations?

The limitations of einsum are likely due to the limited scope of the underlying kernels and strategies implemented for it. You might get better results if, e.g., your computation maps more directly onto something like bmm (torch.bmm — PyTorch 1.11.0 documentation) rather than expressing it via einsum.
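To illustrate the mapping onto bmm: a stack of independent matrix multiplies like the one in the question can be phrased directly as torch.bmm by moving the "which op" dimension to the front. A sketch, with the same hypothetical shapes as above:

```python
import torch

batch, n_ops, d_in, d_out = 4, 3, 5, 2
x = torch.randn(batch, n_ops, d_in)    # one input vector per op
W = torch.randn(n_ops, d_out, d_in)    # one weight matrix per op

# bmm expects the batch-of-matmuls dim first:
#   (n_ops, batch, d_in) @ (n_ops, d_in, d_out) -> (n_ops, batch, d_out)
y_bmm = torch.bmm(x.transpose(0, 1), W.transpose(1, 2)).transpose(0, 1)

# Same result as the einsum formulation from the question.
y_einsum = torch.einsum('bni,noi->bno', x, W)
assert torch.allclose(y_bmm, y_einsum, atol=1e-6)
```

Since bmm is a single dedicated batched-GEMM kernel, it avoids whatever dispatch and contraction-planning overhead einsum adds, which may matter at small sizes.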