Performance of broadcast matrix solvers

TLDR: Computing an inverse and then multiplying is much faster than using a solver when A is (1, 10, 10) and b is (5000, 1, 10), while the solver is fastest when A is (5000, 10, 10) and b is (1, 1, 10). The difference is less pronounced on CUDA than on CPU. However, when both A and b have multiple leading dimensions, the inverse-multiply still seems to be the fastest way to compute a solution in torch on either device.

I started using torch in the last year, mainly to replace numpy and scipy in linear algebra calculations. The broadcasting capabilities of torch are really eye-opening; it's cool how many operations generalize to N dimensions!

What confuses me, however, is the performance of broadcasting with the torch solvers. I am trying to convince my team that they should be using matrix solvers instead of calculating an inverse matrix and then multiplying. At issue is a performance-critical inversion, where A is a Hermitian matrix and b is a set of vectors, with a few shared and broadcast dimensions outside of the final matrix/vector dimensions.

I believe the theory here is clear: we should use a Cholesky decomposition and solver. In practice this does not play out, which may be related to the small matrix sizes (the matrices in A are 10 x 10). However, the issue persists with torch.linalg.solve and with an explicit LU decomposition as well.
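
For reference, the standard argument for a Hermitian positive-definite A is to factor once and then back-substitute, never forming the inverse explicitly:

A = L L^H        (Cholesky factorization, done once)
L y = b          (forward substitution)
L^H x = y        (back substitution)

This yields x = A^-1 b without ever materializing A^-1, which avoids the extra work of building the inverse and is generally better behaved numerically.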

The problem may be in the broadcasting itself. It seems like the solvers do not perform well when the decomposition result is reused many times, i.e. when a singleton dimension of A is broadcast up against a large dimension of b. Here is the benchmark I used to test this:

import torch
import numpy as np
import time

dim_0 = 10    # size of each matrix (dim_0 x dim_0)
# Swap dim_1 and dim_2 to switch which case is being timed
dim_1 = 1     # leading (batch) dimension of B
dim_2 = 5000  # leading (batch) dimension of A
device = torch.device('cpu')

def make_array(device):
    B = torch.randn((dim_1, dim_0), dtype=torch.complex64, device=device)
    snapshots = torch.randn((dim_2, dim_0), dtype=torch.complex64, device=device)
    # Rank-1 outer products: A[k] = s_k s_k^H (Hermitian; made positive definite by the loading below)
    A = torch.einsum('...i, ...j->...ij', snapshots, snapshots.conj())
    # make sure everything is contiguous
    A = A.contiguous()

    # Create a diagonal matrix with the loading value
    loading = 5e-1 * torch.eye(A.shape[-1], dtype=torch.complex64, device=device)
    # Reshape the diagonal matrix to broadcast with the CSDM tensor
    loading_shape = (A.ndim - 2) * (1,) + (A.shape[-1], A.shape[-1])
    loading = loading.reshape(loading_shape)
    A = A + loading
    return A, B

def time_solve(solver):
    num_loops = 20
    total_time = 0.
    for i in range(num_loops):
        A, B = make_array(device)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        tmp = time.time()
        solver(B, A)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        # Treat the first iteration as a warm-up and exclude it from the average
        if i > 0:
            total_time += time.time() - tmp
    # Average time per call in microseconds
    return total_time * 1e6 / (num_loops - 1)

def solve_cholesky(steering_vectors, csdm):
    csdm_chol = torch.linalg.cholesky(csdm)
    beams_chol = torch.cholesky_solve(steering_vectors.unsqueeze(-1), csdm_chol).squeeze(-1)
    return beams_chol

def solve(steering_vectors, csdm):
    beams = torch.linalg.solve(csdm, steering_vectors.unsqueeze(-1)).squeeze(-1)
    return beams

def solve_inverse(steering_vectors, csdm):
    csdm_inverted = torch.linalg.inv(csdm)
    beams_inv = torch.einsum("...ij,...j->...i", csdm_inverted, steering_vectors)
    return beams_inv

print("cholesky:", time_solve(solve_cholesky))
print("solve:   ", time_solve(solve))
print("inverse: ", time_solve(solve_inverse))
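
As a sanity check (not included in the timings), the three code paths can be compared directly; the tolerance below is a rough choice for complex64:

A, B = make_array(device)
beams_chol = solve_cholesky(B, A)
beams_solve = solve(B, A)
beams_inv = solve_inverse(B, A)
print(torch.allclose(beams_chol, beams_solve, atol=1e-4),
      torch.allclose(beams_chol, beams_inv, atol=1e-4))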

After more digging, it seems there are cases where computing a matrix inverse is preferable to a decomposition. Repeated solves of small systems of equations with the same A seem to fit cleanly into this use case.
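
A minimal sketch of that use case (the size, loading value, and names here are just illustrative, mirroring the benchmark above): a single fixed Hermitian A, with batches of right-hand sides arriving repeatedly.

n = 10
A = torch.randn(n, n, dtype=torch.complex64)
# Hermitian positive definite: A A^H plus diagonal loading
A = A @ A.conj().transpose(-2, -1) + 0.5 * torch.eye(n, dtype=torch.complex64)

A_inv = torch.linalg.inv(A)  # O(n^3) cost paid once up front

for _ in range(3):  # e.g. new batches of steering vectors arriving over time
    b = torch.randn(5000, n, dtype=torch.complex64)
    # Applying the precomputed inverse is just a broadcast matrix-vector product
    x = torch.einsum("...ij,...j->...i", A_inv, b)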