How to express tensor contraction efficiently? #einsum

I have 2 tensors of the following dimensions:

A: n x i x o
B: n x b x i

and I would like to compute the tensor C of dimension n x b x o. Here, n denotes the number of feature maps, o is the output dimension, i is the input dimension, and b is the batch size.

Think of A, B, C as stacks of matrices. The operation I’m looking for is essentially map-wise matrix multiplies.

What would be the most GPU-efficient way to express this computation?


Would

C = torch.einsum('nio,nbi->nbo', [A, B])

do the trick?

Would that be correct and reasonably efficient? If not, what’s a better alternative?

Note that I can change the orders of the dimensions to make the computation more efficient if necessary.

This particular form is a batch matrix multiplication if you swap A and B, so you could use torch.bmm(B, A) or torch.matmul(B, A) directly.
Einsum will (currently) reduce to bmm, so performance will be similar, apart from the cost of a few permutations.
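
To make the equivalence concrete, here is a minimal sketch that computes C all three ways and checks that the results agree. The sizes are arbitrary values picked for illustration, and the tensors are random placeholders; only the shapes (A: n x i x o, B: n x b x i) come from the question.

```python
import torch

# Arbitrary sizes chosen for illustration (not from the original question).
n, b, i, o = 4, 8, 16, 32

A = torch.randn(n, i, o)  # one (i x o) matrix per feature map
B = torch.randn(n, b, i)  # one (b x i) matrix per feature map

# The einsum from the question: contract over i, keep n as the batch dimension.
C_einsum = torch.einsum('nio,nbi->nbo', A, B)

# Same result as a batched matmul with the operands swapped:
# (n, b, i) @ (n, i, o) -> (n, b, o)
C_bmm = torch.bmm(B, A)
C_matmul = torch.matmul(B, A)

print(C_einsum.shape)
print(torch.allclose(C_einsum, C_bmm, atol=1e-5))
print(torch.allclose(C_einsum, C_matmul, atol=1e-5))
```

With this layout n is already the leading dimension of both tensors, so bmm can be called on them without any transposes.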

Best regards