Matrix Multiplication Along Axis

I have a bunch of matrices M1, M2, …, Mk in a tensor of shape (k, d, d). I want to compute the matrix product M1 @ M2 @ … @ Mk. This would give an output tensor of shape (d, d). Is there a fast way to do this in PyTorch?
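For reference, the requested product is a left-to-right fold of matrix multiplication over the first axis. The sketch below assumes matrices is a (k, d, d) tensor, and the function name chain_product is hypothetical. It pins down the semantics with functools.reduce, but it still issues k − 1 separate matmuls from Python, so it defines the target rather than answering the speed question.

```python
import functools

import torch


def chain_product(matrices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: matrices[0] @ matrices[1] @ ... @ matrices[k-1]."""
    # Iterating over a tensor yields its slices along dim 0, i.e. the k matrices.
    return functools.reduce(torch.matmul, matrices)


matrices = torch.randn(5, 3, 3)
result = chain_product(matrices)
print(result.shape)  # torch.Size([3, 3])
```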

I looked at some questions that claim to be about this, "How do do matrix multiplication (matmal) along certain axis?" and "Matrix multiplication along specific dimension", but they seem to be concerned with ordinary tensor contraction, as can be done with einsum. I don’t think einsum can solve my problem. Or am I missing something?
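torch.einsum can express the chain, but only by giving each matrix its own pair of subscripts, so the equation string grows with k and no single fixed equation reduces along the first axis. A minimal sketch, where the helper name einsum_chain is hypothetical and the 52-letter limit assumes einsum subscripts are restricted to a-z and A-Z:

```python
import string

import torch


def einsum_chain(mats: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mats[0] @ ... @ mats[k-1] as a single einsum call."""
    k = mats.shape[0]
    letters = string.ascii_letters  # 'a'-'z' then 'A'-'Z': 52 subscript letters
    if not 1 <= k < len(letters):
        raise ValueError("a chain of k matrices needs k + 1 distinct subscripts")
    # Matrix i gets subscripts (letters[i], letters[i + 1]), so neighbours share one index.
    inputs = ",".join(letters[i] + letters[i + 1] for i in range(k))
    equation = f"{inputs}->{letters[0]}{letters[k]}"  # e.g. "ab,bc,cd->ad" for k = 3
    return torch.einsum(equation, *mats)


matrices = torch.randn(3, 4, 4, dtype=torch.float64)
out = einsum_chain(matrices)
print(torch.allclose(out, matrices[0] @ matrices[1] @ matrices[2]))  # True
```

This still unpacks k separate operands, so it does not remove the per-matrix overhead the question is about.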

Hi @Thomas_Ahle,

You can use torch.linalg.multi_dot, which computes the optimal ordering for your k matrices. However, you’ll need to pass your (k, d, d) tensor as a list of matrices instead, which can be done with a list comprehension.

import torch
matrices = torch.randn(4, 6, 6)
torch.linalg.multi_dot([m for m in matrices])
"""
returns 
tensor([[-20.4766, -18.9575,   2.6211,  -9.0951,   7.8230,   6.8106],
        [ 13.0261,   5.3627, -18.5586,  -5.2984, -24.7720,   0.4994],
        [  6.4090,  14.7333,   2.8740,   3.7405,   5.7821,  -7.2868],
        [ -2.2964,   0.6236,   8.3580,  -0.3444,   5.8850,  -6.5138],
        [-14.0015,  -8.9284,  15.0120,   0.9839,  18.0842,  -0.1127],
        [-10.8341,  -0.6622,   9.4417,   2.6874,  19.8331,   1.5707]])
"""

Hi @AlphaBetaGamma96, the ordering doesn’t matter much for me: since the matrices are all square, every ordering needs the same number of operations.
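The claim that ordering does not matter for square matrices follows from the usual cost model, in which multiplying a p×q matrix by a q×r matrix takes pqr scalar multiply-adds. Every intermediate product of d×d factors is again d×d, and any parenthesization of a k-matrix chain performs exactly k − 1 multiplications:

```latex
\operatorname{cost}(A_{p \times q}\, B_{q \times r}) = pqr,
\qquad
\operatorname{cost}(M_1 M_2 \cdots M_k) = (k-1)\, d^3
\quad \text{for every parenthesization of } d \times d \text{ factors.}
```

So under this model, the reordering that multi_dot performs has nothing to gain here.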

What I want is to avoid the Python loop.

I timed the simple loop

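def mymul(mats):
    # Left-to-right product mats[0] @ mats[1] @ ... @ mats[-1].
    res = mats[0]
    for mat in mats[1:]:
        # res = res @ mat allocates a new tensor, so the input is never modified in place.
        res = res @ mat
    return res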

against multi_dot:

>>> matrices = torch.randn(10**3, 2, 2)
>>> %timeit mymul(matrices)
3.72 ms ± 24.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit torch.linalg.multi_dot(list(matrices))
585 ms ± 14.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So torch.linalg.multi_dot(list(matrices)) is more than 100x slower than the naive Python loop. I also tried 10**4 matrices, but then multi_dot just crashed.
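One way to cut the Python overhead without changing the mathematical result is a pairwise (tree) reduction: multiply neighbouring pairs in one batched matmul, then repeat on the halved stack. This is a different technique from both mymul and multi_dot; it relies only on associativity, keeps the left-to-right order of the factors, and runs about log2(k) Python iterations instead of k − 1. A minimal, unbenchmarked sketch, with the hypothetical name tree_matmul:

```python
import functools

import torch


def tree_matmul(mats: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mats[0] @ mats[1] @ ... @ mats[k-1] for a (k, d, d) tensor."""
    if mats.shape[0] == 0:
        raise ValueError("need at least one matrix")
    while mats.shape[0] > 1:
        k = mats.shape[0]
        # One batched matmul computes (M0 @ M1), (M2 @ M3), ... for all pairs at once.
        paired = mats[0:k - 1:2] @ mats[1:k:2]
        if k % 2 == 1:
            # With an odd count, the last matrix is carried over unchanged at the end,
            # so the left-to-right order of the factors is preserved.
            paired = torch.cat([paired, mats[-1:]], dim=0)
        mats = paired
    return mats[0]


matrices = torch.randn(7, 3, 3, dtype=torch.float64)
out = tree_matmul(matrices)
ref = functools.reduce(torch.matmul, matrices)  # plain left-to-right loop
print(torch.allclose(out, ref))  # True
```

The total work is still k − 1 matrix products; only the number of Python-level iterations shrinks (10 rounds for k = 1000). Because the products are associated differently, results can differ from the loop in the last floating-point bits.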