My use case is to project every hidden state output by a transformer through a linear layer. There are two ways to do this: broadcast using matmul, or use einsum. I found that einsum was about 4x faster.
In this Colab notebook, I set up the code for each and profile each method.
For those who don't want to open Colab, these are the equivalent operations I am comparing:
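A minimal sketch of the two equivalent operations (the shapes here are placeholders, not the notebook's actual sizes): broadcast matmul with an unsqueeze/squeeze pair versus an einsum expressing the same contraction over the hidden dimension.

```python
import torch

# Hypothetical shapes: a batch of sequences of hidden states, and a projection weight.
batch, seq, hidden, out_dim = 2, 5, 16, 8
h = torch.randn(batch, seq, hidden)   # transformer hidden states
W = torch.randn(out_dim, hidden)      # linear projection weight

# Method 1: broadcast matmul — add a trailing singleton dim, then squeeze it back off.
# W broadcasts against h's leading (batch, seq) dims.
out_matmul = torch.matmul(W, h.unsqueeze(-1)).squeeze(-1)

# Method 2: einsum — contract over the hidden dim directly.
out_einsum = torch.einsum("bsh,oh->bso", h, W)

# Both produce a (batch, seq, out_dim) tensor with the same values.
assert torch.allclose(out_matmul, out_einsum, atol=1e-5)
```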
At first I thought it might have been the squeeze and unsqueeze operations, so I made a version where the inputs were pre-unsqueezed and there were no squeeze operations afterward; the timings were about the same.
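For illustration, the pre-unsqueezed variant could be timed along these lines (a crude CPU wall-clock sketch with placeholder shapes; on GPU you would also need torch.cuda.synchronize() around the timed region):

```python
import time
import torch

batch, seq, hidden, out_dim = 8, 128, 256, 256
h4 = torch.randn(batch, seq, hidden, 1)  # pre-unsqueezed: trailing singleton dim added up front
W = torch.randn(out_dim, hidden)

def time_it(fn, iters=20):
    # Crude average wall-clock time per call.
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# No squeeze/unsqueeze inside either timed call.
t_matmul = time_it(lambda: torch.matmul(W, h4))                  # (batch, seq, out_dim, 1)
t_einsum = time_it(lambda: torch.einsum("bshx,oh->bsox", h4, W)) # same output shape
print(f"matmul: {t_matmul:.6f}s  einsum: {t_einsum:.6f}s")
```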
I don’t know why einsum() is faster than matmul(), but we have seen things like this before, for example, as in this post:
(We’ve also seen cases where einsum() is unexpectedly and unreasonably slow.)
As an aside, I might guess that it would be better to describe this as an (unexpected) slowdown in matmul(), rather than as a speedup in einsum().
Have you considered comparing the matmul() timings with a loop version?
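A loop version for such a comparison might look like this (a sketch with placeholder shapes, not the notebook's code): project each sequence position separately with a plain 2-D matmul and stack the results.

```python
import torch

batch, seq, hidden, out_dim = 2, 5, 16, 8
h = torch.randn(batch, seq, hidden)
W = torch.randn(out_dim, hidden)

# Loop version: one small (batch, hidden) @ (hidden, out_dim) matmul per position.
out_loop = torch.stack([h[:, t] @ W.T for t in range(seq)], dim=1)

# Broadcast-matmul version for comparison — should match the loop exactly.
out_bcast = torch.matmul(W, h.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(out_loop, out_bcast, atol=1e-5)
```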
I think I may have found another reason: I suspect that einsum uses more GPU memory in this case, because the maximum inference batch size I can run is lower.
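One way to check that suspicion would be to compare peak allocations with torch.cuda.max_memory_allocated(). This is a hedged sketch (the helper name and shapes are mine, not from the notebook); it returns None on CPU-only machines, where CUDA memory stats don't apply.

```python
import torch

def peak_mem_bytes(fn, *args):
    """Run fn(*args) and report peak CUDA memory allocated, or None without a GPU."""
    if not torch.cuda.is_available():
        fn(*args)
        return None
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

h = torch.randn(2, 5, 16)
W = torch.randn(8, 16)

m_matmul = peak_mem_bytes(lambda a, b: torch.matmul(b, a.unsqueeze(-1)).squeeze(-1), h, W)
m_einsum = peak_mem_bytes(lambda a, b: torch.einsum("bsh,oh->bso", a, b), h, W)
print("matmul peak:", m_matmul, "einsum peak:", m_einsum)
```

If einsum's peak really is higher, that would be consistent with it materializing a larger intermediate for this contraction.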