Torch.linalg.eig parallelisation

Hello,

I am using torch for some computations, but am struggling to improve the performance of my code. I need the eigenvalues of matrices on the order of 6000x6000. This works fine with torch.linalg.eig, but it is rather slow compared to the rest of my computations. If I understand correctly, this is largely because part of the computation needs to be performed on the CPU.

What I don’t understand is that the .eig function does not seem to parallelize at all. If I hand it a batch of matrices, the calculation takes pretty much exactly the same amount of time as calling eig for each matrix individually. This is frustrating, as I have plenty of CPU cores and RAM, and even multiple GPUs, available.

Additionally, I observed a slowdown with too many threads: with more than ~16 threads the eig calculation actually gets slower, and a speedup is only observed up to about 8 threads. Is this to be expected?

I am unsure how to tackle this problem. Would it be best to try to parallelize using DDP, or is that not suitable for linalg calculations? Depending on the problem, I will have about 300 of the matrices mentioned above, which should all be computable independently. I would prefer to have gradients available, if that is possible with parallelization.

Best

Hi @Johannes99,

You could try using torch.vmap to vectorize over the torch.linalg.eig call, although you’ll probably have to chunk this operation to avoid an OOM error. The torch.linalg.eig method does run on CUDA, so I’d check that CUDA is available via torch.cuda.is_available(), docs here
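
A minimal sketch of that vmap idea (sizes here are illustrative, not the 6000x6000 case); the chunk_size argument bounds how many matrices are materialized at once:

```python
import torch

# Illustrative sizes only; chunk_size limits how many matrices are processed
# at once, which bounds peak memory at the cost of some parallelism.
mats = torch.randn(8, 16, 16)
eigvals = torch.vmap(torch.linalg.eigvals, chunk_size=2)(mats)
print(eigvals.shape)  # torch.Size([8, 16])
```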

Just a note: you can take the gradients of an eigendecomposition with respect to the matrix as long as the eigenvalues are unique. There’s more info on this on the respective doc page here. Furthermore, you can also use the has_aux flag when computing gradients with torch.func.jacrev to return the output (i.e. the eigendecomposition) as well as its derivative in a single call, docs here
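
As a small sketch of both points (a random matrix almost surely has distinct eigenvalues, so the gradient is well defined; the scalar below is phase-invariant):

```python
import torch
from torch.func import grad

# Plain autograd through the eigenvalues
A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
loss = torch.linalg.eigvals(A).abs().sum()  # real, phase-invariant scalar
loss.backward()

# torch.func.grad with has_aux=True: gradient and eigenvalues from one call
def f(M):
    lam = torch.linalg.eigvals(M)
    return lam.abs().sum(), lam  # (scalar to differentiate, auxiliary output)

g, lam = grad(f, has_aux=True)(A.detach())
```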

Also, more threads doesn’t necessarily mean better performance if the overhead per thread is greater than the work it’s used to compute.

Thanks! I will look into that.

In the meantime I have experimented with getting multiprocessing to work. It seems to be very important to have MKL configured correctly: MKL’s automatic threading appears to be very bad at managing multiple concurrent threads. So with some trial and error (and some help from ChatGPT) I have a solution that feels somewhat unsatisfying, but is at least usable:

The parts of the algorithm that run on the CPU (I am not sure of the backend, but cuSolver’s geev implementation mentions a hybrid CPU-GPU architecture, so maybe that is the backend?) only seem to respect MKL_NUM_THREADS some of the time. Sometimes during execution I can see a lot of Jupyter kernel threads lighting up all 32 available CPU cores. If I set MKL_NUM_THREADS to values larger than 2 (given I have 2 CUDA instances running, I figure this should come out to a maximum of 4 cores being hit at once), the system basically freezes during the eigenvalue computation. This is super confusing to me. These high thread spikes also take more time the more threads I allow, but are still visible even with the thread count set to 1. Even with just one GPU, more than 3 threads slow down performance notably. On the other hand, during large parts of the calculation I can see only one core working (or two in the two-GPU use case). I got additional performance and consistency out of my system by following this guide: https://www.intel.com/content/www/us/en/docs/onemkl/developer-guide-windows/2023-0/call-onemkl-functions-from-multi-threaded-apps.html

currently I have set:

os.environ["MKL_THREADING_LAYER"] = "intel"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["MKL_DYNAMIC"] = "FALSE"
os.environ["KMP_BLOCKTIME"] = "0"
os.environ["OMP_WAIT_POLICY"] = "PASSIVE"

Which seems to get the most out of my system (which is very beefy, Intel Xeon Platinum 8358 CPU, 192GB RAM, 2 A100 80GB)
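
One caveat worth noting: the MKL_* environment variables generally only take effect if they are set before torch (and thus MKL) is first imported. As a runtime alternative, torch exposes its own knob for the intra-op thread pool, which in MKL builds should also cap MKL’s thread count. A small sketch:

```python
import torch

# torch.set_num_threads caps the intra-op thread pool that CPU kernels
# (including the LAPACK-backed linalg routines) use; unlike the MKL_*
# environment variables, it can be changed at runtime.
torch.set_num_threads(1)
print(torch.get_num_threads())  # 1
```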

It seems absolutely counterintuitive to me that more threads can actually slow down a calculation, even when I am not multithreading my workload myself. If anyone is more familiar with the backends, I would love to gain some insight on the matter.

Hi Johannes!

First, eigendecomposition is a rather expensive operation, so it’s not surprising
that it adds cost to your overall computation.

Second, if you only need the eigenvalues (and not the eigenvectors), you might
try linalg.eigvals(). This might be somewhat faster (but probably not a whole
lot as it still has to do much of the same work as linalg.eig()).

I don’t know that this is the dominant cost of linalg.eig(), but I do believe
it is a non-trivial part of it. (I don’t understand this part of the algorithm,
but I take it at face value that it’s been implemented efficiently and for good
reason.)

This doesn’t surprise me. The individual steps within a typical eigenvalue
algorithm aren’t really things that parallelize well over the outer batch
dimension. Furthermore, when I run eig() (on reasonably large matrices),
pytorch seems to do a fine job of saturating the gpu and / or cpu cores.
Given that, parallelizing over the batch dimension would just be another
way – not necessarily better – of keeping the cores full.

First note that if your batch of matrices is a single tensor, and therefore on
a single gpu, eig() will not be able to (and you probably wouldn’t want it to)
parallelize the batch computation across multiple gpus.

But if you have multiple gpus available (on the same cpu / “node”), you should
be able to get some speed-up by parallelizing it by hand. Note that moving
tensors from one gpu to another preserves the computation graph, so this approach
won’t interfere with your gradient computation.

I don’t have multiple gpus, but I can demonstrate this approach with two distinct
devices, namely the cpu and gpu. Here is an illustrative script:

import torch
print (torch.__version__)
print (torch.version.cuda)

_ = torch.manual_seed (2025)

devices = ['cpu', 'cuda']

t = torch.randn (2, 4, 4, requires_grad = True, device = devices[0])   # batch of two 4x4 matrices

l = t.unbind()                                                         # split batch into tuple of 4x4 matrices
l = [m.to (d)  for m, d in zip (l, devices)]                           # list of matrices on multiple devices
e = [torch.linalg.eig (m).eigenvalues  for m in l]                     # list of eigenvalue vectors

# check devices
for  f in e:  print (f.device)

loss = torch.zeros (1, device = devices[0])                            # accumulate batch loss on device[0]
for  f in e:
    loss += f.norm().to (devices[0])                                   # note that loss is phase invariant

loss.backward()

# check that gradients flow back to original tensor on devices[0]
print ('t.grad = ...')
print (t.grad)

And here is its output:

2.8.0+cu129
12.9
cpu
cuda:0
t.grad = ...
tensor([[[-0.2030, -0.2141,  0.4669,  0.2993],
         [ 0.4673,  0.2443,  0.1762, -0.7616],
         [ 0.1722, -0.2288, -0.3123, -0.0708],
         [ 0.2497, -0.0167,  0.6192,  0.1789]],

        [[ 0.0458,  0.1440,  1.2318,  0.4386],
         [ 0.1663,  0.1700,  0.4918, -0.0615],
         [-0.0775,  0.3317,  0.2586,  0.1593],
         [-0.4495, -0.3338, -1.1500, -0.4152]]])

(There might be a way to make pytorch do something like this automatically, but
I don’t know of one.)

I’ve also run this with a batch of two 6000x6000 matrices. It takes about a
minute, with significant usage of both the cpu and gpu (but note that, as written,
this script generates the random matrices on the cpu).

As an aside, as Alpha noted, this is not surprising. Leaving pytorch and eig()
aside, if you tell your system to use too few threads, you won’t use all of your
cores to their capacity, while if you tell it to use too many threads (and it
actually listens to you and does so), your system will spend time swapping
threads on and off various cores – to be “fair” to all of the threads – actually
slowing things down.

As a first draft, at least with modern software like pytorch, you generally don’t
manage threading yourself unless you have a good reason to think that you
can do better. For example, my example script saturates my gpu part of the
time, so pytorch is presumably launching multiple overlapping gpu “kernels” to
perform the gpu version of the eig() computation. My cpu also becomes mostly
saturated part of the time (by which I mean that many of the cpu cores are mostly
saturated), so pytorch is presumably also distributing the cpu version of the
eig() computation across multiple cpu threads so that multiple cpu cores can
be used simultaneously.

I’m not particularly familiar with DDP, but I doubt that it would be useful for
your use case. In my mind, DDP would be appropriate for much coarser-grained
parallelism.

(If you have multiple servers, perhaps with more than one gpu per server, rather
than a single server with multiple gpus, something like DDP, which parallelizes
across multiple processes, would be necessary, but it introduces a lot of
communication overhead, which would seem to be a bad fit for your use case.)

To recap: With the exception of splitting the matrices in your batch across
multiple gpus (by hand, if there isn’t some pytorch feature that could do that
for you), I think the best parallelization you will get is that that is already
implemented in pytorch’s eig(). By now, eig() is a pretty standard pytorch
building block and is most likely already well optimized.

Best.

K. Frank

Thank you very much for the insights!

Have you compared your code against just calculating the eigenvalues one by one?

If I run it on my machine I see just one CPU core being used for large parts of the calculation. My exact code is:

import torch
print (torch.__version__)
print (torch.version.cuda)
from time import time
#torch.backends.cuda.matmul.allow_tf32 = False

_ = torch.manual_seed (2025)

devices = ['cuda:0', 'cuda:1']

t = torch.randn (2, 6000, 6000, requires_grad = True, device = devices[0], dtype = torch.complex128)   # batch of two 6000x6000 matrices

l = t.unbind()                                                         # split batch into tuple of matrices
l = [m.to (d)  for m, d in zip (l, devices)]                           # list of matrices on multiple devices
start = time()
e = [torch.linalg.eig (m).eigenvalues  for m in l]                     # list of eigenvalue vectors


# check devices
for  f in e:  
    print (f.device)
    print(f)
print(time() - start)

loss = torch.zeros (1, device = devices[0])                            # accumulate batch loss on device[0]
for  f in e:
    loss += f.norm().to (devices[0])                                   # note that loss is phase invariant

loss.backward()

# check that gradients flow back to original tensor on devices[0]
print ('t.grad = ...')
print (t.grad)


On the other hand, if I start the calculations from two threads I get roughly a 1.9x speedup compared to single GPU execution. This is the code I am using to do so:

import time
from concurrent.futures import ThreadPoolExecutor

import torch

# `worker` and `complex_dtype_for` are helpers defined elsewhere in my code:
# `worker` runs torch.linalg.eig on one chunk of matrices on its device, and
# `complex_dtype_for` maps a real dtype to the matching complex dtype.
def run_two_gpus(X: torch.Tensor, devices: list[torch.device]):
    N, n = X.shape[0], X.shape[-1]
    c_dtype = complex_dtype_for(X.dtype)
    eigvals = torch.empty((N, n), dtype=c_dtype)
    eigvecs = torch.empty((N, n, n), dtype=c_dtype)

    idx_chunks = torch.chunk(torch.arange(N), len(devices))
    t0 = time.perf_counter()
    with ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = [pool.submit(worker, X, idx_chunks[k], devices[k], c_dtype, n) for k in range(len(devices))]
        for fut in futures:
            idxs, vals_part, vecs_part = fut.result()
            if idxs.numel() > 0:
                eigvals[idxs] = vals_part
                eigvecs[idxs] = vecs_part
    for dev in devices:
        if dev.type == "cuda":
            torch.cuda.synchronize(dev)
    t1 = time.perf_counter()
    return eigvals, eigvecs, t1 - t0

I think the key might be in the part of the linalg.eig documentation that says “When inputs are on a CUDA device, this function synchronizes that device with the CPU.” I am not familiar with the matter, but my suspicion is that only one CUDA device can be synchronized with the CPU at a time, or something of that sort.

On the part about using too few or too many threads… I am just wondering how, with 16 available cores, a theoretical maximum of 12 threads could lead to oversubscription of the available cores.

Hi Johannes!

The speedup you get from using python-level multi-threading is very interesting
and led me to make the following observations.

I have only one gpu, so I can’t explore running on multiple gpus, with or without
multi-threading. I do, however, see a significant speedup using multi-threading on
both the cpu and on the (single) gpu.

Here is my timing script:

import torch
print (torch.__version__)
print (torch.version.cuda)

import concurrent.futures
from time import time

_ = torch.manual_seed (2025)

n = 6000
nb = 4

nWarm = 3
nTime = 10

def batchEig (tb):
    return  torch.linalg.eig (tb)

def listEig (tb):
    return  [torch.linalg.eig (t)  for t in tb]

def poolEig (tb):
    with concurrent.futures.ThreadPoolExecutor(max_workers = 8) as pool:
        futures = [pool.submit (torch.linalg.eig, t)  for t in tb]
        return  [fut.result()  for fut in futures]

print ('nb:', nb)
print ('n: ', n)
for  dev in ('cpu', 'cuda'):
    print (dev, 'timings:')
    tBatch = torch.randn (nb, n, n, device = dev, requires_grad = True)   # batch of nb nxn matrices
    for  eFunc in (batchEig, listEig, poolEig):
        for  i in range (nWarm):
            eig = eFunc (tBatch)
        if  dev == 'cuda':
            torch.cuda.synchronize()
        t0 = time()
        for  i in range (nTime):
            eig = eFunc (tBatch)
        if  dev == 'cuda':
            torch.cuda.synchronize()
        t1 = time()
        print ('{:9s} {:8.2f} sec'.format (eFunc.__name__ + ':', (t1 - t0) / nTime))

And here are the resulting timings:

2.8.0+cu129
12.9
nb: 4
n:  6000
cpu timings:
batchEig:    94.46 sec
listEig:     95.00 sec
poolEig:     64.78 sec
cuda timings:
batchEig:    61.81 sec
listEig:     62.24 sec
poolEig:     42.38 sec

I would say that this is an explicit performance bug in pytorch’s linalg.eig()
implementation (and probably also some other linalg implementations`). If I can
achieve meaningful speedup of an implementation that observationally already
uses multiple cpu cores by wrapping it in python-level multi-threading, then that
implementation is obviously missing some straightforward performance enhancements.

You might consider filing a github issue about this.

Some further commentary:

When I run my timings, I do not see the bulk of the computation being throttled
down to a single cpu core (or a small number of cpu cores). For most of the
computation I see many cores being mostly saturated. I do see – whether the
tensor being processed is on the cpu or gpu – brief periods where the cpu usage
seems to be restricted to one, or perhaps two, cpu cores, but this is at most a
smallish fraction of the overall computation.

I do see significant speedup applying eig() to a cuda tensor vs a cpu tensor.
However, when monitoring the gpu and the cpu cores I see only brief bursts of
gpu activity embedded in fairly constant heavy cpu usage. This seems plausible;
the small bursts of gpu activity could certainly be performing operations that
would take much longer on the cpu, reducing the load on the cpu somewhat.

I speculate that the cuda implementation lives in libcusolver.so and that this is
a proprietary nvidia library whose source code is not visible. Absent the source
code, whether I just haven’t been able to find it or whether it’s unavailable, I’m
limited to speculation rather than concrete analysis.

For this example, I see the gpu reducing eig()'s run time by about 35%. While
useful, this is quite modest compared to the gpu speedup I’ve seen for other
tensor and linear-algebra computations.

Best.

K. Frank

It is great to see someone being able to reproduce the issue.

I will probably take this to GitHub. Actually I think I know now, where the issue comes from:

Torch only has a general eigenvalue solver from magma implemented (I cannot find any calls to the cuSolver implementation in the source). As far as I can tell (without any prior knowledge on the matter), magma just does some preparation on the GPU and then uses CPU LAPACK, in my case probably MKL, as the true backend. If I had to guess, the MKL on the cluster is just messed up, which is why I get this weird reaction to the thread counts. I actually get better performance on my home machine without having to fiddle with MKL variables.

So in my mind the easiest solution would be to parallelize the magma calls, if that is possible in the pytorch source, or, as magma is open source, to do so at the magma level.

Hi Johannes!

My problem is that I can’t find a call to any cuda eig() implementation in the
source. (In general I find it quite hard to navigate the pytorch source tree.)

Coming back to the possible case of cusolver, I see in my (most recent) conda
environment:

ls ~/miniconda3/envs/2_8_0/lib/python3.13/site-packages/nvidia/cusolver/lib

libcusolverMg.so.11  libcusolver.so.11

Looking inside (for what it’s worth):

strings libcusolver.so.11 | grep -i geev | sort | uniq

cusolverDnXgeev
cusolverDnXgeev_bufferSize
/dvs/p4/build/sw/rel/gpgpu/toolkit/r12.9/cuSolver/src/cudense/geev.cu
geev_duration={}s
xgeev_local
@_ZN37_INTERNAL_49463fea_7_geev_cu_02aa27374cuda3std6ranges3__45__cpo4swapE
_ZN37_INTERNAL_49463fea_7_geev_cu_02aa27374cuda3std6ranges3__45__cpo4swapE
<ZN37_INTERNAL_49463fea_7_geev_cu_02aa27374cuda3std6ranges3__45__cpo4swapE[1

geev seems to be a name dating back to lapack (or before?) for something like
“general eigenvalue (and eigenvector)” computation.

Again, not necessarily relevant, but in the nvidia docs (in the 2.4.5.7. cusolverDnXgeev() section)
I find this:

Remark 3: geev is a hybrid CPU-GPU algorithm. Best performance is attained with pinned host memory.

So if pytorch really delegates eig() to cusolver for cuda tensors, this would
explain a lot. Whether a “hybrid algorithm” is just an expedient or is really
more or less the best one can do, I don’t know.

But regardless, this note in the high-level pytorch eig() docs:

When inputs are on a CUDA device, this function synchronizes that device with the CPU.

is misleading at best. It should really say something like “When inputs are on
a CUDA device, this function uses a hybrid algorithm that offloads about a third
of the cpu workload onto the CUDA device.”

At least on my system, for your use case, while it’s worth using the gpu to gain
a reduction in run time of about a third, there seems to be no benefit in trying
to use multiple gpus, as the gpu is not a bottleneck (and is only used at a tiny
fraction of its capacity).

(However, regardless of whether the hybrid algorithm is a sound approach, the
fact that both the cpu and gpu batch versions can be accelerated by wrapping
them in a high-level python thread pool indicates a performance bug, at least for
batched eig().)

Best.

K. Frank

My problem is that I can’t find a call to any cuda eig() implementation in the
source

Hi Frank, I think I can help with that. The relevant files seem to be in this folder pytorch/aten/src/ATen/native/cuda/linalg at main · pytorch/pytorch · GitHub. cuSolver implementations seem to be in the CUDASolver.cpp files, while Magma implementations seem to be in the BatchedLinearAlgebra.cpp file. There I see some functions called

 magmaEig<c10::complex<double>, double> 

I think these are the functions called when calculating eigenvalues on the GPU. I seriously doubt that cuSolver is used for this. I also modified one of the examples found here https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdnxgeev and tried it on my private machine. It is MUCH faster: just using a stopwatch, I see 17 s for a 6000x6000 double matrix, where torch needs about 40 s. (The example runs in C++, so there might be somewhat less overhead due to Python, but as torch basically does the calculations in C++ anyway, I would guess the overhead should not be that large.)

Also, the load I can observe in Task Manager on my GPU is very different. The Magma call seems to just briefly load the GPU and then run on my CPU. I would guess that only the preparation of the calculation happens on the GPU. I think this also matches your observed speedup.

However, for my use case, there is a very good reason to use multiple GPUs: it lets multiple calculations run at once and therefore loads my CPU better. I can process two matrices at once on my machine with basically no visible overhead, so I gain a 2x speedup from using two GPUs. It might also be possible to achieve this using multiple CUDA streams, but I am not sure, so I will need to try it out. This might also be the easiest way to speed up these calculations in the source: just parallelize them as well as possible. But for the sake of actually making this a CUDA algorithm, I think it would be best to implement the cuSolver Xgeev algorithm and make it use batched execution as far as possible.
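
To make the stream idea concrete, here is a hypothetical sketch (whether the two calls actually overlap depends on how much of geev runs on the host, so this is something to measure, not a guaranteed speedup):

```python
import torch

def eig_on_two_streams(a: torch.Tensor, b: torch.Tensor):
    # Hypothetical sketch: issue two eig calls on separate CUDA streams and
    # see whether they overlap. Returns None on machines without a GPU.
    if not torch.cuda.is_available():
        return None
    s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
    with torch.cuda.stream(s1):
        ra = torch.linalg.eig(a.cuda())
    with torch.cuda.stream(s2):
        rb = torch.linalg.eig(b.cuda())
    torch.cuda.synchronize()  # wait for both streams before using the results
    return ra, rb

res = eig_on_two_streams(torch.randn(64, 64), torch.randn(64, 64))
```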

Totally agree on that. It should be made way clearer in the docs what is actually going on, as the behaviour is very unclear.

Yes, I also think it would definitely be worth trying to parallelize this at the C++ level, which could lead to immense performance gains. As far as I know, eigenvalue calculations are also used in clustering and optimization problems, so this might not just be a niche problem for someone running a weird physics problem in Torch.

Best

Johannes

Hi Johannes!

Thanks for pointing me to this code – it’s what I had been looking for.

I agree with this. I do see calls for what look like symmetric / hermitian
eigendecomposition in CUDASolver.cpp but not for general eigendecomposition.

Nice find and test. Why would pytorch be ignoring cusolver?

If you do log a github issue, you should certainly propose this as a clear
improvement.

Just some guesswork on my part …

Based on the picture I got from watching the performance monitors for my timing
tests, it seems to me like the cpu is still the bottleneck. Do you have a sense
in your cusolver test of how your run time was split between the cpu and gpu?

If you have a batch of matrices (whether a list or a tensor) AND the cpu is not
the bottleneck, I do believe that using python threading (e.g., a thread pool) to
distribute the individual matrices across multiple gpus should work fine.

Whether using multiple cuda streams to process a batch on a single gpu would
be helpful, I just don’t know – I don’t have a clear enough understanding of the
internal gpu architecture.

I do believe that eigendecomposition is “commonly” used in neural networks and
machine learning. Sure, convolutions and point-wise activations are big use cases
and should be optimized first (as they have been), but eigendecomposition is
hardly something I would call niche. And outside of machine learning, it’s a real
workhorse of scientific computation.

Thanks for your analysis and testing!

Best.

K. Frank

Hi,

This is a very interesting question. I have three possible explanations, and I only really like one of them:

  1. DNXgeev is not accurate/stable enough. I doubt this, as I have not found anything hinting at it, but it might be the case. This would leave optimizing the Magma code, mainly making it batch properly, as the way to go. Also, Xgeev only calculates right eigenvectors. I think this is the default behavior of torch.linalg.eig as well, but it could be another limitation.
  2. It might be harder/impossible to track the graph of DNXgeev. This might also apply to batched Magma geev implementations. I don’t think this is the case either, as I don’t see anything regarding the backward calculation in the implementation in ATen/native…. Probably the backward implementations of the linalg functions are completely separate? I would love to learn how this is implemented.
  3. cuSolver DnXgeev is just new, torch had to fall back to Magma because of CUDA limitations at some point, and no one noticed that cuSolver has since implemented a geev algorithm. I hope this is the case, and there are some clues pointing to it: CuPy does not implement a geev algorithm, and even ChatGPT is reluctant to admit that eigenvalue problems can be solved (at least in large part) on CUDA devices. From the CUDA release notes I can see that cusolverDnXgeev was only introduced in CUDA 12.8 (https://docs.nvidia.com/cuda/archive/12.8.0/cuda-toolkit-release-notes/). As far as I understand, this was only released this year. So it is brand new, and this is the most probable explanation.

I did not (yet) do proper profiling of the execution (although I certainly plan to). But it seemed to me that the execution was 99% on the GPU. I could not see any relevant CPU load, although this was on Windows with Chrome consuming some CPU… Also, the drop in GPU load came very close to the start of the output of the results. So my guess would be that only a minimal part of the calculation is done on the CPU. If this holds true and the speed actually scales with GPU power (I will try my best to get DnXgeev running on an A100 and H100 and get some performance comparisons), this could really offer a huge speedup in many applications.
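
For the profiling step, a minimal torch.profiler sketch (CPU-only here; on a GPU box, adding ProfilerActivity.CUDA to activities should show the kernel-level split):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Small matrix so the profile runs quickly; the table shows where eig spends time.
A = torch.randn(256, 256)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.linalg.eig(A)
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(table)
```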

Well, it absolutely does for me. As said previously, I think this is some weird MKL bug on one of my machines, as I don’t see it anywhere else. Only one core is loaded during the eig calculation. As soon as I use a second GPU, I see a second core being loaded, and the calculations don’t really seem to slow each other down; I see a 2x speedup when doing so, so I think there must be some weird bottleneck. Additionally, it might be a good idea to try to parallelize multiple eig calculations at once when using the CPU. In my understanding, overhead should be much lower when calculating multiple independent problems at once than when spreading a single problem over multiple cores.

I will try this. I just hope to get my GPU to feed multiple CPU cores with multiple matrices to speed up the calculations even more. But this can only be an intermediate solution. If there are no major problems, getting DnXgeev working in torch should be a priority.

I actually think it should not be extremely hard to implement this, as the calls seem to be largely the same as for the Hermitian solver, so we can probably reuse a lot of the code. But I have no idea what is necessary for the backward calls to work.

I will try to setup a working build environment to maybe do this myself, but I will definitely need some help doing this. I do however think that this will be absolutely worth it for Torch.

Best
Johannes

EDIT: As this turned into a backend development issue I have turned to the dev forum https://dev-discuss.pytorch.org/t/cusolver-dnxgeev-faster-cuda-eigenvalue-calculations/3248

EDIT2: I just tried the DnXGeev on an H100. Turns out it is a lot faster than my 4070. Got about 4 seconds for a 4000x4000 double matrix.

So, to put a nice close to this thread: I managed to integrate the new cuSolver path into pytorch. The pull request is just running final checks, and the code will probably be merged within the next hours. You should be able to get it with one of the coming nightlies. As stated in the PR: the speedup is about 2x on consumer hardware, with high-performance hardware achieving a lot more (up to the 10x I observed on the H100, but that was only a quick test).

However, none of the eig backends are really parallelized currently, so work remains.

If anyone wants to take a look at the PR while it is merging: https://github.com/pytorch/pytorch/pull/166715


Hi Johannes!

Good man!

Could you ping this thread when the public-facing nightly becomes available? (I don’t really
know how to track the pull / build / nightly process.) I’d like to run my timings on it.

Best.

K. Frank

Tbh I don’t really know either.

I found a GitHub Action named “nightly” that executes at 00:00 UTC, so the nightly from tomorrow morning might already include my changes. If you want to test it, watch your GPU usage: the old MAGMA backend will make it spike only briefly, while cuSolver will load it pretty much the entire time it is calculating. If I test it myself, I will let you know once I am certain that it actually made it into the nightly.

Best

Johannes

Just checked: it is available in the 1104 nightly. If you download the nightly now, you get the new cuSolver path, which is about 4x faster for fp32 and 1.3x faster for fp64 on an RTX 4070.

But the kicker comes on the A100: 5x faster for fp32 and 7x faster for fp64 🙂

EDIT: Numbers for H100: 10x faster for fp32 and 15x faster for fp64. Crushing my best expectations.

Hi Johannes!

Thanks for the heads-up – looks great!

I tweaked my timing script to use torch.backends.cuda.preferred_linalg_library() to
switch between magma and cusolver to run timing comparisons. For my hardware and test
case, cusolver runs about 9 times faster than magma (and about 4.5 times faster than magma
with a threadpool applied to the batch dimension). (Using a threadpool has no discernible
effect on the cusolver timings.)

Here is the timing script:

import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())

import concurrent.futures
from time import time

torch.backends.cuda.preferred_linalg_library ('cusolver')   # to mask subsequent warning

_ = torch.manual_seed (2025)

n = 6000
nb = 4

nWarm = 3
nTime = 10

def batchEig (tb):
    return  torch.linalg.eig (tb)

def poolEig (tb):
    with concurrent.futures.ThreadPoolExecutor(max_workers = 8) as pool:
        futures = [pool.submit (torch.linalg.eig, t)  for t in tb]
        return  [fut.result()  for fut in futures]

print ('nb:', nb)
print ('n: ', n)
# for  dev in ('cpu', 'cuda'):
for  dev in ('cuda',):
    print (dev, 'timings:')
    if  dev == 'cuda':  backends = ('cusolver', 'magma')
    else:               backends = ('default',)             # just a placeholder
    for  be in backends:
        torch.backends.cuda.preferred_linalg_library (backend = be)
        if  dev == 'cuda':  print (be, '(preferred backend):')  
        tBatch = torch.randn (nb, n, n, device = dev, requires_grad = True)   # batch of nb nxn matrices
        tBatch = (tBatch + tBatch.mT) / 2                                     # make them symmetric
        for  eFunc in (batchEig, poolEig):
            for  i in range (nWarm):
                eig = eFunc (tBatch)
            if  dev == 'cuda':  torch.cuda.synchronize()
            t0 = time()
            for  i in range (nTime):
                eig = eFunc (tBatch)
            if  dev == 'cuda':  torch.cuda.synchronize()
            t1 = time()
            print ('{:10s} {:8.2f} sec'.format (eFunc.__name__ + ':', (t1 - t0) / nTime))

And here is its output running on the 1104 nightly:

2.10.0.dev20251104+cu130
13.0
NVIDIA RTX 3000 Ada Generation Laptop GPU
[W1104 19:29:09.865638494 Context.cpp:432] Warning: torch.backends.cuda.preferred_linalg_library is an experimental feature. If you see any error or unexpected behavior when this flag is set please file an issue on GitHub. (function operator())
nb: 4
n:  6000
cuda timings:
cusolver (preferred backend):
batchEig:     20.02 sec
poolEig:      19.89 sec
magma (preferred backend):
batchEig:    184.54 sec
poolEig:      93.99 sec

For completeness, I also ran it on the current stable release (without your fix), 2.9. This is
kind of boring because there is no cusolver (preferred_linalg_library() is a request,
not a demand), so the attempted cusolver timings just repeat the magma timings. Here is
that output:

2.9.0+cu130
13.0
NVIDIA RTX 3000 Ada Generation Laptop GPU
[W1104 22:09:54.363336967 Context.cpp:470] Warning: torch.backends.cuda.preferred_linalg_library is an experimental feature. If you see any error or unexpected behavior when this flag is set please file an issue on GitHub. (function operator())
nb: 4
n:  6000
cuda timings:
cusolver (preferred backend):
batchEig:    181.90 sec
poolEig:      91.24 sec
magma (preferred backend):
batchEig:    179.35 sec
poolEig:      91.05 sec

Thanks again! This is a big step forward.

From all of this I conclude that the magma implementation just isn’t very good. Not only does
it make poor use of the gpu, but its use of the overall system is poor, as shown by the fact that
using a crude, top-level python threadpool significantly improves performance.

(Perhaps these general results explain why the intel “xpu” branch of pytorch doesn’t yet really
support torch.linalg – apparently they would have to do significant work not to appear
equally silly.)

Best.

K. Frank