I want to compute the following batched matrix product, summed over the k axis:
(k, N, N) @ (b, N, N) @ (k, N, N)^H -> (b, N, N)
This can be done in many different ways with PyTorch's various matrix multiplication functions, but which is the fastest for large matrices? Below is a benchmark of three approaches (tensordot seems to be the fastest so far), but I can't help but wonder: are there faster implementations I'm missing?
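Written out, with A^H denoting the conjugate transpose (A.adjoint() in the code), each variant computes the contraction below. Evaluated pairwise in either order, it costs about 2kbN^3 complex multiply-adds, which for the sizes used in the benchmark works out as follows:

```latex
Y_j = \sum_{a=1}^{k} A_a \, X_j \, A_a^{\mathsf H}, \qquad j = 1, \dots, b

\text{cost} \approx \underbrace{kbN^3}_{A_a X_j} + \underbrace{kbN^3}_{(A_a X_j)\,A_a^{\mathsf H}}
  = 2kbN^3 = 2 \cdot 3 \cdot 4 \cdot 100^3 = 2.4 \times 10^{7}
```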
import torch
import timeit
k = 3
b = 4
N = 100
device = torch.device("cuda")
dtype = torch.complex64
A = torch.randn(k, N, N, dtype=dtype, device=device)
X = torch.randn(b, N, N, dtype=dtype, device=device)
# 83.8 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit (A[:, None] @ X[None, :] @ A[:, None].adjoint()).sum(0)
# 240 µs ± 211 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit torch.einsum('aij,bjk,akl->bil', A, X, A.adjoint())
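# 71.9 µs ± 84.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each), measured without the trailing .permute(1, 0, 2)
# The tensordot chain returns shape (N, b, N); .permute(1, 0, 2) returns a (b, N, N) view of it
%timeit torch.tensordot(torch.tensordot(A, X, dims=([2], [1])), A.adjoint(), dims=([0, 3], [0, 1])).permute(1, 0, 2)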
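One further variant that might be worth adding to the comparison is to fold the sum over k into a single GEMM. This is a sketch, not a claim that it is faster. The idea is to concatenate the A[a] side by side into an N × kN matrix and to stack the products X[b] @ A[a]^H into a kN × N matrix per batch entry. The block checks it, together with the shape-corrected tensordot chain, against the einsum result. It then times all three with torch.utils.benchmark.Timer. CUDA kernels launch asynchronously, so timings taken without a synchronization can be misleading, and Timer synchronizes CUDA for you. The function names and the CPU fallback are illustrative assumptions, not part of the original benchmark. Whether the stacked variant wins will depend on the GPU and on N, k and b.

```python
import torch
from torch.utils import benchmark

# Assumption: fall back to the CPU so the sketch also runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.complex64
k, b, N = 3, 4, 100

torch.manual_seed(0)
A = torch.randn(k, N, N, dtype=dtype, device=device)
X = torch.randn(b, N, N, dtype=dtype, device=device)


def einsum_variant(A, X):
    # Y[j] = sum_a A[a] @ X[j] @ A[a]^H, as in the question.
    return torch.einsum("aij,bjk,akl->bil", A, X, A.adjoint())


def tensordot_variant(A, X):
    AX = torch.tensordot(A, X, dims=([2], [1]))  # shape (k, N, b, N)
    Y = torch.tensordot(AX, A.adjoint(), dims=([0, 3], [0, 1]))  # shape (N, b, N)
    return Y.permute(1, 0, 2)  # shape (b, N, N)


def stacked_matmul_variant(A, X):
    # Hypothetical alternative: fold the sum over a into a single GEMM.
    n = A.shape[-1]
    # T[j, a] = X[j] @ A[a]^H, shape (b, k, N, N); the reshape to (b, k*N, N) is a view.
    T = (X[:, None] @ A.adjoint()[None]).reshape(X.shape[0], -1, n)
    # A_cat[i, a*N + m] = A[a, i, m], i.e. [A[0] | A[1] | ... | A[k-1]], shape (N, k*N).
    # If A is fixed across calls, this small copy can be computed once up front.
    A_cat = A.permute(1, 0, 2).reshape(n, -1)
    # (N, k*N) @ (b, k*N, N) broadcasts to (b, N, N) and sums over a and m together.
    return A_cat @ T


variants = {
    "einsum": einsum_variant,
    "tensordot": tensordot_variant,
    "stacked_matmul": stacked_matmul_variant,
}

# Check every variant against the einsum reference (relative Frobenius error).
expected = einsum_variant(A, X)
for name, fn in variants.items():
    out = fn(A, X)
    rel_err = ((out - expected).norm() / expected.norm()).item()
    print(f"{name}: shape {tuple(out.shape)}, relative error {rel_err:.1e}")

# Timer synchronizes CUDA around each measurement when a GPU is present.
# Note: Timer uses num_threads=1 by default, which only matters for CPU runs.
timings = {}
for name, fn in variants.items():
    timer = benchmark.Timer(stmt="fn(A, X)", globals={"fn": fn, "A": A, "X": X})
    timings[name] = timer.timeit(50).mean  # mean seconds per call
    print(f"{name}: {timings[name] * 1e6:.1f} us per call")
```

Because A_cat only depends on A, the stacked variant can also hoist that copy out of the timed call when A is reused across many X batches.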