Fastest way to batch multi_dot

I want to perform the following matrix multiplication (the k axis is contracted away, as in the einsum below):

(k, N, N) @ (b, N, N) @ (k, N, N) -> (b, N, N)

which can be achieved in many different ways using the various PyTorch matrix multiplication functions. But what is the fastest way to do this for large matrices? Below is a benchmark of three different approaches (tensordot seems to be the fastest so far), but I can't help wondering: are there other, faster implementations I'm missing?

import torch
import timeit

k = 3
b = 4
N = 100
device = torch.device("cuda")
dtype = torch.complex64

A = torch.randn(k, N, N, dtype=dtype, device=device)
X = torch.randn(b, N, N, dtype=dtype, device=device)

# 83.8 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit (A[:, None] @ X[None, :] @ A[:, None].adjoint()).sum(0)
# 240 µs ± 211 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit torch.einsum('aij,bjk,akl->bil', A, X, A.adjoint())
# 71.9 µs ± 84.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# note: this one returns shape (N, b, N); a cheap .permute(1, 0, 2) view is needed to match the others
%timeit torch.tensordot(torch.tensordot(A, X, dims=([2], [1])), A.adjoint(), dims=([0, 3], [0, 1]))
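
One caveat on the numbers above: these are CUDA kernels, and plain %timeit does not synchronize the device, so for small kernels it can partly measure launch overhead. A minimal sketch of a synchronized measurement using torch.utils.benchmark (which handles the CUDA sync for you), shown here for the first variant:

import torch.utils.benchmark as benchmark

# Timer synchronizes CUDA around each measurement, unlike bare %timeit
t = benchmark.Timer(
    stmt="(A[:, None] @ X[None, :] @ A[:, None].adjoint()).sum(0)",
    globals={"A": A, "X": X},
)
print(t.timeit(1000))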

Any ideas on this would be really appreciated! It's a core operation in a very time-sensitive part of my code, so any gain is a big win.

I’m going with the tensordot approach for now since it seems fastest, but it doesn’t feel efficient to run two separate contractions for a somewhat symmetric operation such as

(A[:, None] @ X[None, :] @ A[:, None].adjoint()).sum(0)
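
For what it's worth, one reformulation I've been playing with folds the k axis into the GEMM dimensions, so the whole contraction runs as two large matmuls plus a cheap permute. This is only a sketch: it should reproduce the (b, N, N) result of the broadcasted-matmul version, but I haven't benchmarked whether it's actually faster:

T = torch.matmul(A.reshape(k * N, N), X)       # (b, k*N, N): contracts j
T = T.reshape(b, k, N, N).permute(0, 2, 1, 3)  # (b, N, k, N): bring i forward
T = T.reshape(b, N, k * N)                     # fold (k, l) into one GEMM axis
out = T @ A.adjoint().reshape(k * N, N)        # (b, N, N): contracts k and l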

Thanks for your help :slight_smile: