I recently discovered torch.einsum, and it seems that it can be useful for parallelizing linear operations on the GPU, since it lets you control which dimensions are summed over. It seems others have had the same idea, as I see it in popular open source PyTorch code such as hf transformers/modeling_bert.py at main · huggingface/transformers · GitHub
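As a minimal sketch of what I mean (shapes and names are mine, not from any particular library): applying n independent weight matrices to n separate batches in one einsum call, instead of looping, by keeping the n dimension as a batch dimension and summing only over the input-feature dimension.

```python
import torch

# n separate linear layers applied in parallel:
# x: (n, batch, d_in) -- one input batch per layer
# W: (n, d_in, d_out) -- one weight matrix per layer
n, batch, d_in, d_out = 4, 8, 16, 32
x = torch.randn(n, batch, d_in)
W = torch.randn(n, d_in, d_out)

# Loop version: one matmul per layer, stacked afterwards
looped = torch.stack([x[i] @ W[i] for i in range(n)])

# einsum version: all n matmuls in a single call;
# 'i' is summed out, while 'n' and 'b' are kept as batch dims
fused = torch.einsum('nbi,nio->nbo', x, W)

assert torch.allclose(looped, fused, atol=1e-5)
print(fused.shape)  # torch.Size([4, 8, 32])
```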
I wrote some thoughts about it here
In short, it seems that it can be a powerful tool for parallelizing separate linear operations on the GPU. But this might be too good to be true, as discussed in this closed issue
They fixed the major slowdown issues, but from my limited understanding of the discussion, it seems that there are still some inherent slowdowns when using einsum.
If this is true, is there a scale at which einsum becomes desirable? Is there a more ideal method for parallelizing separate linear operations?