Cusolver and magma timings for linalg.eigh() (symmetric case)

Hello Forum (and Johannes)!

The short story:

cusolver outperforms magma for linalg.eigh() (for my test case).

cusolver saturates my gpu with light-to-moderate cpu usage, while magma only shows
moderate gpu usage and heavy (but not saturated) cpu usage.

(I monitor gpu usage with “nvidia-smi -l 1” and cpu usage with the ubuntu / gnome
graphical “System Monitor.”)

I became curious about the linalg “backends” because of this thread:

It appears that pytorch only supports magma for linalg.eig() (the general
asymmetric case), but poking around in the code that Johannes showed me indicated
that both cusolver and magma are supported for linalg.eigh() (the symmetric
and hermitian case), so I ran some comparative timings.

Here is the timing script:

import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())

import concurrent.futures
from time import time

torch.backends.cuda.preferred_linalg_library ('cusolver')   # to mask subsequent warning

_ = torch.manual_seed (2025)

n = 6000
nb = 4

nWarm = 3
nTime = 10

def batchEigh (tb):
    return  torch.linalg.eigh (tb)

def poolEigh (tb):
    with concurrent.futures.ThreadPoolExecutor(max_workers = 8) as pool:
        futures = [pool.submit (torch.linalg.eigh, t)  for t in tb]
        return  [fut.result()  for fut in futures]

print ('nb:', nb)
print ('n: ', n)
for  dev in ('cpu', 'cuda'):
# for  dev in ('cuda',):
    print (dev, 'timings:')
    if  dev == 'cuda':  backends = ('cusolver', 'magma')
    else:               backends = ('default',)             # just a placeholder
    for  be in backends:
        torch.backends.cuda.preferred_linalg_library (backend = be)
        if  dev == 'cuda':  print (be, '(preferred backend):')  
        tBatch = torch.randn (nb, n, n, device = dev, requires_grad = True)   # batch of nb nxn matrices
        tBatch = (tBatch + tBatch.mT) / 2                                     # make them symmetric
        for  eFunc in (batchEigh, poolEigh):
            for  i in range (nWarm):
                eigh = eFunc (tBatch)
            if  dev == 'cuda':  torch.cuda.synchronize()
            t0 = time()
            for  i in range (nTime):
                eigh = eFunc (tBatch)
            if  dev == 'cuda':  torch.cuda.synchronize()
            t1 = time()
            print ('{:10s} {:8.2f} sec'.format (eFunc.__name__ + ':', (t1 - t0) / nTime))

And here is its output:

2.8.0+cu129
12.9
NVIDIA RTX 3000 Ada Generation Laptop GPU
[W1006 12:43:32.703334536 Context.cpp:320] Warning: torch.backends.cuda.preferred_linalg_library is an experimental feature. If you see any error or unexpected behavior when this flag is set please file an issue on GitHub. (function operator())
nb: 4
n:  6000
cpu timings:
batchEigh:    24.04 sec
poolEigh:     29.87 sec
cuda timings:
cusolver (preferred backend):
batchEigh:     7.42 sec
poolEigh:      6.65 sec
magma (preferred backend):
batchEigh:    10.87 sec
poolEigh:      8.18 sec

Symmetric eigendecomposition is an inherently easier problem (and its algorithms
may also be more mature and efficient), as evidenced by the cpu timings.
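As a small aside (my own illustration, not part of the timing script): for a symmetric input, linalg.eigh() returns real eigenvalues in ascending order and orthonormal eigenvectors, so the input can be reconstructed as V diag(w) V^T:

```python
import torch

# Sketch: verify the eigh() decomposition on a small symmetric matrix.
torch.manual_seed(2025)
a = torch.randn(5, 5, dtype=torch.float64)
a = (a + a.mT) / 2                        # symmetrize, as in the timing script

w, v = torch.linalg.eigh(a)               # w: eigenvalues, v: eigenvectors

assert torch.all(w[:-1] <= w[1:])         # eigenvalues come back in ascending order
recon = v @ torch.diag(w) @ v.mT          # reconstruct a from its decomposition
print(torch.allclose(recon, a, atol=1e-10))   # True
```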

The cusolver implementation significantly – but not hugely – outperforms magma.

This adds to the argument that pytorch should use cusolver by default for the
general (asymmetric) linalg.eig().

(Based on Johannes’s reported timings for a non-pytorch cusolver general
eigendecomposition test vs. the pytorch / magma version, it appears that the
relative performance improvement of cusolver over magma would be even larger
for eig() (general) than for eigh() (symmetric).)
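One concrete way to see the difference between the two problems (again my own illustration): for a general real matrix, eig() must return complex eigenvalues and eigenvectors, while eigh() on a symmetrized copy returns purely real results:

```python
import torch

# Sketch: eig() on a general real matrix yields complex dtypes, while
# eigh() on its symmetrized version stays real.
torch.manual_seed(2025)
a = torch.randn(4, 4, dtype=torch.float64)

wg, _ = torch.linalg.eig(a)                 # general case
ws, _ = torch.linalg.eigh((a + a.mT) / 2)   # symmetric case

print(wg.dtype)   # torch.complex128
print(ws.dtype)   # torch.float64
```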

A couple of comments:

Again, parallelizing the batch dimension with a python thread pool increased
magma’s performance significantly (although not by as much as it did for eig().)

magma again used the cpu heavily, but it did use the gpu more heavily (though
still far from saturation) than it did for eig().

I don’t know how to inquire which backend is actually being used, but because
the timings and usage pattern differ depending on which backend is specified
in the torch.backends.cuda.preferred_linalg_library() call, I deduce that
cusolver and magma are being used according to which is chosen as “preferred.”
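For what it’s worth, calling preferred_linalg_library() with no argument returns the currently preferred backend, though (as far as I know) this only reports the preference, not which backend a given call actually dispatched to:

```python
import torch

# Sketch: set the preference, then read it back with a no-argument call.
torch.backends.cuda.preferred_linalg_library('cusolver')
print(torch.backends.cuda.preferred_linalg_library())   # e.g. _LinalgBackend.Cusolver
```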

(Not relevant to cuda, but using the thread pool hurt the cpu performance for
eigh(), rather than helping it significantly as it did for eig().)

Of course, both the absolute and relative timings will depend on the hardware
and size of the test tensor.

Best.

K. Frank
