torch.einsum is around 4x faster than broadcasting torch.matmul for my use case

My use case is to project every hidden state coming out of a transformer using a linear layer. There are two ways to do this: broadcast using matmul, or use einsum. I found that using einsum was about 4x faster.

In this Colab notebook, I set up the code for each and profile each method.

For those who don’t want to open Colab, these are the equivalent operations I am comparing:

torch.matmul(phase2.unsqueeze(2), weight.unsqueeze(2)).squeeze(2).squeeze(2)

and

torch.einsum('abc, bc -> ab', phase2, weight)

At first I thought it might have been the squeeze and unsqueeze operations, so I tried a version where the inputs were pre-unsqueezed and no squeeze operations were applied afterwards, and the timings were about the same.
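
For anyone who wants to reproduce the comparison without the notebook, a minimal timing sketch along these lines should work. This is not the notebook itself, and the tensor shapes below are placeholders, not my real ones:

```python
# Minimal timing sketch -- not the actual notebook; the shapes below are
# placeholders, not the real ones.
import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
phase2 = torch.randn(32, 128, 768, device=device)  # (a, b, c) -- assumed shape
weight = torch.randn(128, 768, device=device)      # (b, c)    -- assumed shape

def matmul_version(x, w):
    # broadcasted matmul: (a, b, 1, c) @ (b, c, 1) -> (a, b, 1, 1) -> (a, b)
    return torch.matmul(x.unsqueeze(2), w.unsqueeze(2)).squeeze(2).squeeze(2)

def einsum_version(x, w):
    return torch.einsum('abc, bc -> ab', x, w)

def bench(fn, *args, iters=100):
    for _ in range(10):           # warm-up
        fn(*args)
    if device == "cuda":
        torch.cuda.synchronize()  # make sure queued kernels are done
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# both formulations should agree numerically
assert torch.allclose(matmul_version(phase2, weight),
                      einsum_version(phase2, weight), atol=1e-4)

print("matmul:", bench(matmul_version, phase2, weight))
print("einsum:", bench(einsum_version, phase2, weight))
```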

What could be causing the 4x speedup in einsum?

Hi Santosh!

I don’t know why einsum() is faster than matmul(), but we have seen things
like this before, for example, as in this post:

(We’ve also seen cases where einsum() is unexpectedly and unreasonably
slow.)

As an aside, I might guess that it would be better to describe this as an
(unexpected) slowdown in matmul(), rather than as a speedup in einsum().
Have you considered comparing the matmul() timings with a loop version?
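
For concreteness, a loop version could look something like this (just a sketch of the idea, with x of shape (a, b, c) and w of shape (b, c) to match the einsum above):

```python
import torch

def loop_version(x, w):
    # x: (a, b, c), w: (b, c) -> out: (a, b), one plain matmul per position
    out = torch.empty(x.shape[0], x.shape[1], device=x.device, dtype=x.dtype)
    for j in range(x.shape[1]):
        out[:, j] = x[:, j, :] @ w[j]   # (a, c) @ (c,) -> (a,)
    return out
```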

Best.

K. Frank


Thanks for the detailed answer, K. Frank. Will look into a loop version.

I think I may have found another clue: I suspect einsum somehow uses more GPU memory in this case, because the maximum inference batch size I can run with it is lower.

Is there a way to profile this to be sure?
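
One option might be to compare peak allocated memory around each call, e.g. with torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated(). A minimal sketch of that idea (placeholder shapes, assuming a CUDA device):

```python
import torch

# placeholder shapes, assuming a CUDA device
phase2 = torch.randn(32, 128, 768, device="cuda")
weight = torch.randn(128, 768, device="cuda")

def peak_mb(fn, *args):
    # reset the peak-memory counter, run the op, and report the new peak in MB
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2

matmul_fn = lambda x, w: torch.matmul(x.unsqueeze(2), w.unsqueeze(2)).squeeze(2).squeeze(2)
einsum_fn = lambda x, w: torch.einsum('abc, bc -> ab', x, w)

print("matmul peak MB:", peak_mb(matmul_fn, phase2, weight))
print("einsum peak MB:", peak_mb(einsum_fn, phase2, weight))
```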