When we have some n-by-n matrix A and we want to compute A.inv() @ b for a vector b, we usually don't invert A; instead we use torch.linalg.solve(A, b). This returns the result of A.inv() @ b without actually forming the inverse of A, and it is often faster and more numerically stable.
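A minimal sketch of the two routes for a single right-hand side. It assumes float64 and an SPD test matrix built the same way as in the example code below; the names are illustrative only.

```python
import torch

torch.manual_seed(0)
n = 20
# well-conditioned SPD matrix (Gram matrix of random data), float64
X = torch.randn(50 * n, n, dtype=torch.float64)
A = X.T @ X
b = torch.randn(n, dtype=torch.float64)

# explicit inverse, then a matrix-vector product
x_inv = torch.linalg.inv(A) @ b
# solve A x = b directly; the inverse is never formed
x_solve = torch.linalg.solve(A, b)

print(torch.allclose(x_inv, x_solve))  # True
```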
I wonder whether a similar approach exists for computing A.inv() @ B @ A.inv().T, where both A and B are symmetric positive definite.
I tried using torch.linalg.solve twice, but this is often slower than inverting A. I also tried torch.cholesky_solve, but this is also slower than just inverting A.
Is there something about this problem that makes inverting A the best option?
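Because A is symmetric, its inverse is symmetric too. That is why the example code below can reuse A_inv without a transpose, and why the final .T in fun_solve recovers the same matrix (C denotes the result of the first solve):

$$
(A^{-1})^{T} = (A^{T})^{-1} = A^{-1}
\quad\Longrightarrow\quad
A^{-1} B A^{-T} = A^{-1} B A^{-1}
$$

$$
C = A^{-1} B, \qquad \bigl(A^{-1} C^{T}\bigr)^{T} = C A^{-T} = A^{-1} B A^{-T}
$$

With B symmetric as well, the result is itself symmetric: $(A^{-1} B A^{-1})^{T} = A^{-T} B^{T} A^{-T} = A^{-1} B A^{-1}$.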
Example code:
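```python
import torch

n_dim = 20
n_points = n_dim * 50

# A and B are symmetric positive definite (Gram matrices of random data)
X1 = torch.randn(n_points, n_dim)
A = X1.T @ X1
X2 = torch.randn(n_points, n_dim)
B = X2.T @ X2

def fun_invert(A, B):
    # invert once and reuse; A_inv.T == A_inv because A is symmetric
    A_inv = torch.linalg.inv(A)
    return A_inv @ B @ A_inv

def fun_solve(A, B):
    # C = A^-1 B, then (A^-1 C^T)^T = C A^-T = A^-1 B A^-T
    C = torch.linalg.solve(A, B)
    return torch.linalg.solve(A, C.T).T

def fun_chol(A, B):
    # factor A = L L^T once, then two solves against the factor
    L = torch.linalg.cholesky(A)
    C = torch.cholesky_solve(B, L)
    return torch.cholesky_solve(C.T, L).T

# IPython / Jupyter timing magics
%timeit fun_invert(A, B)
%timeit fun_solve(A, B)
%timeit fun_chol(A, B)
```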
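Before comparing timings, it may be worth checking that the three variants agree. A sketch, assuming float64 (so that differences reflect the algorithms rather than float32 round-off) and a fixed seed; the functions are copied from the example above.

```python
import torch

torch.manual_seed(0)
n_dim = 20
n_points = n_dim * 50
# float64 assumed here, unlike the float32 example above
X1 = torch.randn(n_points, n_dim, dtype=torch.float64)
A = X1.T @ X1
X2 = torch.randn(n_points, n_dim, dtype=torch.float64)
B = X2.T @ X2

def fun_invert(A, B):
    A_inv = torch.linalg.inv(A)
    return A_inv @ B @ A_inv

def fun_solve(A, B):
    C = torch.linalg.solve(A, B)
    return torch.linalg.solve(A, C.T).T

def fun_chol(A, B):
    L = torch.linalg.cholesky(A)
    C = torch.cholesky_solve(B, L)
    return torch.cholesky_solve(C.T, L).T

ref = fun_invert(A, B)
for f in (fun_solve, fun_chol):
    # relative Frobenius-norm difference from the explicit-inverse result
    rel_err = torch.linalg.norm(f(A, B) - ref) / torch.linalg.norm(ref)
    print(f.__name__, rel_err.item())
```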
The short story is that cholesky_solve() beats inv() both for larger matrices and when
run on the gpu. Presumably the theoretical disadvantage of inv() is masked by various
overheads when the dimensions of the matrices are relatively small.
That calling solve() twice is slower than inverting A is not surprising. The benefits of using solve() instead of inv() shrink when
you are solving for “multiple right-hand sides,” that is, when your B is a matrix rather than
a vector. When you invert A, you are solving for the n columns of the inverse of A, and
you can formulate this as solving for n right-hand sides. So calling solve() with B, a square n x n matrix, costs broadly the same as inverting A. But for your
use case, you have to call solve() twice, whereas you only have to invert A once and can reuse
the result.
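To make the "n right-hand sides" point concrete: solving A X = I, with the n columns of the identity as right-hand sides, yields the inverse. The sketch below only checks that the results agree (float64 SPD test matrix assumed); it makes no claim about how torch.linalg.inv is implemented internally.

```python
import torch

torch.manual_seed(0)
n = 20
X = torch.randn(50 * n, n, dtype=torch.float64)
A = X.T @ X
I = torch.eye(n, dtype=torch.float64)

# one solve with the n columns of the identity as right-hand sides
A_inv_via_solve = torch.linalg.solve(A, I)
A_inv = torch.linalg.inv(A)

print(torch.allclose(A_inv_via_solve, A_inv))  # True
```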
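As an aside, A being symmetric positive definite does help somewhat. It guarantees that A is invertible, and it
also allows a Cholesky factorization A = L L^T, a direct (i.e., non-iterative) method that exploits the SPD
structure and needs roughly half the floating-point operations of the LU factorization used for a general matrix.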
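PyTorch has SPD-specific routines built on the Cholesky factor. Besides cholesky_solve(), torch.cholesky_inverse forms A^{-1} from the factor L, so the "invert once and reuse" strategy can also use the SPD structure. A sketch only: fun_chol_inv is a hypothetical name, float64 is assumed, and whether this beats torch.linalg.inv for n = 20 would have to be measured.

```python
import torch

def fun_chol_inv(A, B):
    # Cholesky factor A = L L^T (L lower triangular); raises if A is not SPD
    L = torch.linalg.cholesky(A)
    # A^{-1} computed from the Cholesky factor instead of a general inverse
    A_inv = torch.cholesky_inverse(L)
    # A^{-1} is symmetric, so A^{-1} B A^{-T} == A^{-1} B A^{-1}
    return A_inv @ B @ A_inv

torch.manual_seed(0)
n_dim = 20
X1 = torch.randn(50 * n_dim, n_dim, dtype=torch.float64)
A = X1.T @ X1
X2 = torch.randn(50 * n_dim, n_dim, dtype=torch.float64)
B = X2.T @ X2

A_inv_ref = torch.linalg.inv(A)
ref = A_inv_ref @ B @ A_inv_ref
print(torch.allclose(fun_chol_inv(A, B), ref))  # True
```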
I can reproduce your result of cholesky_solve() being slower (when n is relatively small).
In general, however, cholesky_solve() should be faster.
In the full cholesky solution, the decomposition cholesky() is computed only once and reused (sort of like reusing the inverse). You do have to call cholesky_solve() twice, but each call only performs triangular solves against the already-computed factor, with no further factorization.
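With A = L L^T, each cholesky_solve() amounts to a forward substitution with L followed by a back substitution with L^T. The sketch below spells out the two cholesky_solve() calls of fun_chol as four torch.linalg.solve_triangular calls (the second pair solves from the right via left=False). It is meant to show the structure, not as a faster replacement; fun_chol_tri is a hypothetical name and float64 is assumed.

```python
import torch

def fun_chol_tri(A, B):
    L = torch.linalg.cholesky(A)  # A = L @ L.T, L lower triangular
    # first solve, C = A^{-1} B = L^{-T} (L^{-1} B)
    Y = torch.linalg.solve_triangular(L, B, upper=False)    # L Y = B
    C = torch.linalg.solve_triangular(L.mT, Y, upper=True)  # L^T C = Y
    # second solve from the right, X = C A^{-1}, i.e. X L L^T = C
    Z = torch.linalg.solve_triangular(L.mT, C, upper=True, left=False)  # Z L^T = C
    X = torch.linalg.solve_triangular(L, Z, upper=False, left=False)    # X L = Z
    return X

torch.manual_seed(0)
n_dim = 20
X1 = torch.randn(50 * n_dim, n_dim, dtype=torch.float64)
A = X1.T @ X1
X2 = torch.randn(50 * n_dim, n_dim, dtype=torch.float64)
B = X2.T @ X2

X = fun_chol_tri(A, B)
L = torch.linalg.cholesky(A)
ref = torch.cholesky_solve(torch.cholesky_solve(B, L).T, L).T
print(torch.allclose(X, ref))  # True
```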
It looks to me like inverting A wins just due to the vagaries of how the computation for
smaller values of n interacts with the cpu (and gpu) floating-point pipelines.
I find that I can reproduce your result that inverting A is fastest for n = 20 (your n_dim)
running on the cpu, but running with significantly larger n (in the example below, n = 200)
and / or on the gpu, the cholesky approach is faster.
I’ve timed n = 20 vs. n = 200 and cpu vs. gpu with a timing script based on your code.
Here is the script:
```python
import torch
print (torch.__version__)
import time

torch.manual_seed (2025)

device = 'cpu'
if torch.cuda.is_available(): device = 'cuda'
if device == 'cuda':
    print (torch.version.cuda)
    print (torch.cuda.get_device_name())

# n_dim = 20
n_dim = 200
n_points = n_dim * 50
n_batch = 10
print ('device:', device, ' n_dim:', n_dim)

n_warm = 5      # untimed warm-up iterations
n_time = 1000   # timed iterations

# create batches of SPD A and B
X1 = torch.randn(n_batch, n_points, n_dim, device = device)
A = X1.permute (0, 2, 1) @ X1
X2 = torch.randn(n_batch, n_points, n_dim, device = device)
B = X2.permute (0, 2, 1) @ X2

def fun_batch_invert (A, B):
    for a, b in zip (A, B):
        a_inv = torch.linalg.inv (a)
        _ = a_inv @ b @ a_inv

def fun_batch_solve (A, B):
    for a, b in zip (A, B):
        c = torch.linalg.solve (a, b)
        _ = torch.linalg.solve (a, c.T).T

def fun_batch_chol (A, B):
    for a, b in zip (A, B):
        l = torch.linalg.cholesky (a)
        c = torch.cholesky_solve (b, l)
        _ = torch.cholesky_solve (c.T, l).T

# reported times are milliseconds per (A, B) pair;
# synchronize() is needed because CUDA kernels run asynchronously

# time invert
for _ in range (n_warm):
    fun_batch_invert (A, B)
if device == 'cuda': torch.cuda.synchronize()
tStart = time.time()
for _ in range (n_time):
    fun_batch_invert (A, B)
if device == 'cuda': torch.cuda.synchronize()
tEnd = time.time()
print ('time_inv: {:.4f}'.format (1000 * (tEnd - tStart) / (n_batch * n_time)))

# time solve
for _ in range (n_warm):
    fun_batch_solve (A, B)
if device == 'cuda': torch.cuda.synchronize()
tStart = time.time()
for _ in range (n_time):
    fun_batch_solve (A, B)
if device == 'cuda': torch.cuda.synchronize()
tEnd = time.time()
print ('time_sol: {:.4f}'.format (1000 * (tEnd - tStart) / (n_batch * n_time)))

# time chol
for _ in range (n_warm):
    fun_batch_chol (A, B)
if device == 'cuda': torch.cuda.synchronize()
tStart = time.time()
for _ in range (n_time):
    fun_batch_chol (A, B)
if device == 'cuda': torch.cuda.synchronize()
tEnd = time.time()
print ('time_cho: {:.4f}'.format (1000 * (tEnd - tStart) / (n_batch * n_time)))
```
And here is its output for the four different runs:
```
2.7.1+cu128
device: cpu n_dim: 20
time_inv: 0.0244
time_sol: 0.0364
time_cho: 0.0341

2.7.1+cu128
12.8
NVIDIA RTX 3000 Ada Generation Laptop GPU
device: cuda n_dim: 20
time_inv: 0.1210
time_sol: 0.1847
time_cho: 0.0851

2.7.1+cu128
device: cpu n_dim: 200
time_inv: 0.5564
time_sol: 0.8801
time_cho: 0.4294

2.7.1+cu128
12.8
NVIDIA RTX 3000 Ada Generation Laptop GPU
device: cuda n_dim: 200
time_inv: 0.4634
time_sol: 0.8400
time_cho: 0.3604
```
A couple of comments:

- The cpu timings jump around a lot more than the gpu timings. I imagine this is because
  there are various “random” processes running in the background competing for the cpu.
- Empirically, n = 200 is still rather small in the context of the gpu. The benefit of the gpu
  becomes dramatically more pronounced for larger n.
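Two variations that may be worth trying, sketched below without any timings to report. First, torch.linalg.inv, torch.linalg.solve, torch.linalg.cholesky and torch.cholesky_solve all accept batches of matrices, so the Python loop over the batch can be dropped. Second, torch.utils.benchmark.Timer handles warm-up and CUDA synchronization and reports a median, which should be less sensitive to background-process noise on the cpu. The batch_* names are hypothetical, and since Timer defaults to a single thread, num_threads is set explicitly here to match normal multithreaded execution.

```python
import torch
import torch.utils.benchmark as benchmark

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(2025)
n_dim = 200
n_points = n_dim * 50
n_batch = 10

X1 = torch.randn(n_batch, n_points, n_dim, device=device)
A = X1.mT @ X1
X2 = torch.randn(n_batch, n_points, n_dim, device=device)
B = X2.mT @ X2

# batched versions: each call processes all n_batch matrices at once
def batch_invert(A, B):
    A_inv = torch.linalg.inv(A)
    return A_inv @ B @ A_inv

def batch_solve(A, B):
    C = torch.linalg.solve(A, B)
    return torch.linalg.solve(A, C.mT).mT

def batch_chol(A, B):
    L = torch.linalg.cholesky(A)
    C = torch.cholesky_solve(B, L)
    return torch.cholesky_solve(C.mT, L).mT

# sanity check: the batched variants agree with the explicit inverse
ref = batch_invert(A, B)
for fn in (batch_solve, batch_chol):
    rel = torch.linalg.norm(fn(A, B) - ref) / torch.linalg.norm(ref)
    print(fn.__name__, 'relative difference:', rel.item())

for fn in (batch_invert, batch_solve, batch_chol):
    timer = benchmark.Timer(
        stmt='fn(A, B)',
        globals={'fn': fn, 'A': A, 'B': B},
        num_threads=torch.get_num_threads(),
    )
    # blocked_autorange() chooses the number of runs; median is in seconds
    m = timer.blocked_autorange(min_run_time=0.5)
    print(f'{fn.__name__}: {1000 * m.median / n_batch:.4f} ms per (A, B) pair')
```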