Computation-time and memory issues with torch.cholesky and torch.cholesky_inverse

Hi, I’ve been looking into torch.cholesky and torch.cholesky_inverse for large matrices and have experienced some problems with them.

When I run the following script:

import torch

A = torch.randn(30000, 40000, device='cuda')
A = A @ A.t()

%time L = torch.cholesky(A)
del A
%time A_inv = torch.cholesky_inverse(L)

it prints:

CPU times: user 24.5 s, sys: 352 ms, total: 24.9 s
Wall time: 22 s
CPU times: user 3min 36s, sys: 969 ms, total: 3min 37s
Wall time: 42.7 s

I’m really confused about the long run time of cholesky_inverse. As I understand it, it computes L_inv.t() @ L_inv, and since L is lower triangular, each triangular solve costs only O(n^2), so I expected obtaining L_inv to be much cheaper than the O(n^3) Cholesky factorization. However, it runs much longer than the factorization itself. With these sizes, I would expect cholesky_inverse to be nearly instant by comparison.
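For reference, here is a small CPU sanity check of what cholesky_inverse computes: given the Cholesky factor L of a positive-definite A, it should reproduce A^{-1}. (This sketch uses the newer torch.linalg.cholesky spelling, since torch.cholesky has since been deprecated in its favor.)

```python
import torch

# Build a small, well-conditioned positive-definite matrix on CPU.
n = 4
A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.t() + n * torch.eye(n, dtype=torch.float64)

L = torch.linalg.cholesky(A)       # lower-triangular factor, A = L @ L.T
A_inv = torch.cholesky_inverse(L)  # computes (L @ L.T)^{-1}

print(torch.allclose(A_inv, torch.inverse(A)))  # expect True
```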

My other question is about the memory consumption of cholesky. I thought that if I do torch.cholesky(A, out=A) on a 50000 x 50000 float matrix A (9.31 GB), I should not run out of memory on a 12 GB GPU. However, I can see that PyTorch tries to allocate another 9.31 GB, which throws an OOM error. My guess is that even though the Cholesky decomposition can be computed in place, the corresponding MAGMA routine doesn’t do exactly that and allocates another 9.31 GB of memory for performance reasons?

I realized that the very slow run times were due to another user occupying the GPU node and using up all the CPU resources. With the same setup, cholesky now takes ~2.5 s and cholesky_inverse ~0.6 s (I would still expect it to be faster, but it’s no longer a problem).

I’m still wondering, however, whether it’s possible to perform the Cholesky decomposition in place without allocating extra memory.


Another update. It turned out that what I wrote in the previous post was not really the case. I forgot about asynchronous CUDA calls, so the timings were wrong. I added torch.cuda.synchronize() calls and ran the following code:

import torch
from time import time

A = torch.randn(30000, 35000, device='cuda')
A = A @ A.t()
torch.cuda.synchronize()

t = time()
A_inv = torch.inverse(A)
torch.cuda.synchronize()
print('torch.inverse', time() - t)

t = time()
torch.cholesky(A, out=A)
torch.cuda.synchronize()
print('torch.cholesky', time() - t)

t = time()
torch.cholesky_inverse(A, out=A)
torch.cuda.synchronize()
print('torch.cholesky_inverse', time() - t)

Now, the output is:

torch.inverse 16.596264123916626
torch.cholesky 4.46283221244812
torch.cholesky_inverse 90.9823350906372

So again, cholesky_inverse takes extremely long. I would expect it to be many times faster than cholesky, but that doesn’t seem to be the case. What is going on here?
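One possible workaround (my own suggestion, not something confirmed in this thread): the inverse can also be obtained from the Cholesky factor via torch.cholesky_solve with an identity right-hand side, sidestepping cholesky_inverse entirely. A small CPU sketch:

```python
import torch

# Small, well-conditioned positive-definite test matrix.
n = 4
A = torch.randn(n, n, dtype=torch.float64)
A = A @ A.t() + n * torch.eye(n, dtype=torch.float64)

L = torch.linalg.cholesky(A)

# Solve A X = I using the factor L; X is then A^{-1}.
A_inv = torch.cholesky_solve(torch.eye(n, dtype=torch.float64), L)

print(torch.allclose(A_inv @ A, torch.eye(n, dtype=torch.float64)))  # expect True
```

Whether this is actually faster than cholesky_inverse on the GPU for large n would need to be benchmarked (with torch.cuda.synchronize() around the timed region, as above).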

Hi,

We actually use magma for all the heavy lifting here.
In particular, for the non-batched version, we directly call their function. Relevant code here.

I am not sure how to explain this though :confused: