PyTorch diagonalization of a tridiagonal complex symmetric matrix

Hi all!

I am looking for a backprop-friendly algorithm in PyTorch to find the eigenvectors and eigenvalues of a tridiagonal symmetric matrix (only the main diagonal and the first upper and lower diagonals are nonzero).
The variables I am taking the gradient with respect to are the matrix elements, and the loss is a backprop-friendly function of the eigenvalues and eigenvectors.

Any idea?
torch.eig and its variations do not work, since the gradient does not flow through them.

Thank you!

Fabio

Hi @Fabio_Anselmi,

Do you have a minimal reproducible example?

Also, the eigenvalues won’t have a derivative by definition, so you can’t backprop through the eigenvalues; you should be able to backprop through the eigenvectors (I think).

Yes, here is my attempt to solve the problem with the QR algorithm… but the results are not correct. I tried to test it against non-backprop-friendly algorithms and it gives different results :frowning:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

SEED = 467
torch.manual_seed(SEED)

#Constants of the system
n=4
Gamma = torch.rand(n)
Omega = torch.randn(n)# could be negative
g = torch.rand(n-1)

def tridiag_eig(g, Gamma, Omega):
    n = Gamma.shape[0]
    diag = -torch.relu(Gamma)/2 + 1j*Omega    # impose positiveness constraints
    upper_diag = -1j*torch.relu(g)
    tridiag_mat = torch.diag(diag) + torch.diag(upper_diag, 1) + torch.diag(upper_diag, -1)

    # Initialize variables
    eig_vals = tridiag_mat.new_empty(n)
    left_eig_vecs = torch.eye(n, dtype=tridiag_mat.dtype, device=tridiag_mat.device)
    right_eig_vecs = torch.eye(n, dtype=tridiag_mat.dtype, device=tridiag_mat.device)

    # QR Algorithm
    for i in range(100):
        # QR decomposition
        Q, R = torch.linalg.qr(tridiag_mat)

        # Compute the new matrix
        tridiag_mat = torch.mm(R, Q)

        # Update eigenvectors
        left_eig_vecs = torch.mm(left_eig_vecs, Q)
        right_eig_vecs = torch.mm(right_eig_vecs, torch.inverse(Q))

        # Check convergence: the iterates approach upper-triangular form,
        # so it is the strictly lower triangle that should vanish
        if torch.max(torch.abs(torch.tril(tridiag_mat, diagonal=-1))) < 1e-6:
            break

    # Extract eigenvalues
    eig_vals = torch.diag(tridiag_mat)

    return eig_vals, left_eig_vecs, right_eig_vecs

eig_vals, left_eig_vecs, right_eig_vecs = tridiag_eig(g,Gamma,Omega)
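
For reference, a minimal sketch of the kind of comparison I mean (rebuilding the same matrix and checking the sorted eigenvalues against numpy's non-backprop-friendly LAPACK solver):

A = (torch.diag(-torch.relu(Gamma)/2 + 1j*Omega)
     + torch.diag(-1j*torch.relu(g), 1)
     + torch.diag(-1j*torch.relu(g), -1))
ref_vals = np.linalg.eigvals(A.numpy())      # reference eigenvalues from LAPACK
print(np.sort_complex(ref_vals))             # sorted by real part, then imaginary part
print(np.sort_complex(eig_vals.numpy()))     # eigenvalues from the QR loop above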

Hi Fabio!

torch.linalg.eig() works for me (as does torch.linalg.eigh()):

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> n = 5
>>>
>>> d_main = torch.randn (n)
>>> d_lower = torch.randn (n - 1)
>>> m_lower = torch.diag (d_main) + torch.diag (d_lower, -1)
>>> m_tri = m_lower + torch.diag (d_lower, 1)
>>>
>>> torch.linalg.eig (m_tri)                   # general algorithm
torch.return_types.linalg_eig(
eigenvalues=tensor([ 1.4005+0.j,  0.5639+0.j, -1.2217+0.j, -1.1171+0.j, -0.7208+0.j]),
eigenvectors=tensor([[ 0.0027+0.j, -0.0882+0.j, -0.9948+0.j,  0.0512+0.j,  0.0054+0.j],
        [ 0.0450+0.j, -0.9944+0.j,  0.0899+0.j,  0.0295+0.j,  0.0168+0.j],
        [ 0.5250+0.j, -0.0086+0.j, -0.0403+0.j, -0.7924+0.j, -0.3077+0.j],
        [ 0.8454+0.j,  0.0575+0.j,  0.0246+0.j,  0.5224+0.j,  0.0922+0.j],
        [ 0.0875+0.j,  0.0098+0.j, -0.0114+0.j, -0.3092+0.j,  0.9468+0.j]]))
>>>
>>> torch.linalg.eigh (m_lower)                # takes advantage of symmetry of matrix
torch.return_types.linalg_eigh(
eigenvalues=tensor([-1.2217, -1.1171, -0.7208,  0.5639,  1.4005]),
eigenvectors=tensor([[ 0.9948,  0.0512,  0.0054,  0.0882, -0.0027],
        [-0.0899,  0.0295,  0.0168,  0.9944, -0.0450],
        [ 0.0403, -0.7924, -0.3077,  0.0086, -0.5250],
        [-0.0246,  0.5224,  0.0922, -0.0575, -0.8454],
        [ 0.0114, -0.3092,  0.9468, -0.0098, -0.0875]]))
>>>
>>> m_tri.requires_grad = True
>>> loss = torch.linalg.eig (m_tri)[0][2]      # loss is a single eigenvalue
>>> loss.backward()
>>> m_tri.grad
tensor([[ 9.8957e-01, -8.9380e-02,  4.0040e-02, -2.4454e-02,  1.1315e-02],
        [-8.9380e-02,  8.0730e-03, -3.6165e-03,  2.2087e-03, -1.0220e-03],
        [ 4.0041e-02, -3.6166e-03,  1.6201e-03, -9.8947e-04,  4.5784e-04],
        [-2.4454e-02,  2.2088e-03, -9.8947e-04,  6.0430e-04, -2.7962e-04],
        [ 1.1315e-02, -1.0220e-03,  4.5784e-04, -2.7962e-04,  1.2938e-04]])
>>>
>>> m_tri.grad = None
>>> loss = torch.linalg.eig (m_tri)[1][1, 2]   # loss is a single element of a single eigenvector
>>> loss.backward()
>>> m_tri.grad
tensor([[ 0.0634, -0.0057,  0.0026, -0.0016,  0.0007],
        [ 0.5604, -0.0506,  0.0227, -0.0138,  0.0064],
        [-0.2186,  0.0197, -0.0088,  0.0054, -0.0025],
        [ 0.1320, -0.0119,  0.0053, -0.0033,  0.0015],
        [-0.0590,  0.0053, -0.0024,  0.0015, -0.0007]])

[Edit: I overlooked that you were asking about the complex case.]

The situation is similar for the complex case, but even in the (Hermitian)
tri-diagonal case, you can now have complex eigenvectors, so the
eigenvector phase ambiguity makes the gradient of (phase-dependent)
functions of the eigenvectors not well defined. (Gradients of eigenvalues
are still fine.)

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> n = 5
>>>
>>> d_main = torch.randn (n, dtype = torch.complex64)
>>> d_lower = torch.randn (n - 1, dtype = torch.complex64)
>>> m_lower = torch.diag (d_main) + torch.diag (d_lower, -1)
>>> m_tri = m_lower + torch.diag (d_lower, 1).conj()
>>>
>>> torch.linalg.eig (m_tri)                   # general algorithm
torch.return_types.linalg_eig(
eigenvalues=tensor([-2.4665+0.3781j, -1.5699+0.4308j,  1.8091+0.4371j,  0.7244+0.4485j,
         0.0592+0.0190j]),
eigenvectors=tensor([[-4.1827e-01+0.1134j,  6.5241e-01+0.0000j, -1.4480e-01+0.1424j,
         -2.9093e-01+0.4490j,  9.6288e-02+0.1866j],
        [-4.0111e-01-0.3376j,  2.1858e-01+0.2747j,  3.9774e-01+0.0802j,
          6.3391e-01+0.0000j,  1.5024e-02-0.1544j],
        [ 6.0069e-01+0.0000j,  4.2588e-01+0.0718j,  4.8217e-01-0.3556j,
         -5.6944e-02+0.0620j,  6.8369e-04+0.2855j],
        [-3.9204e-01-0.1196j, -4.4602e-01-0.1976j,  6.2306e-01+0.0000j,
         -4.1192e-01+0.1344j, -7.1238e-03+0.1854j],
        [ 2.1511e-02+0.0899j,  4.8561e-04+0.1631j,  1.8057e-01+0.1198j,
         -3.3740e-01+0.0566j,  9.0328e-01+0.0000j]]))
>>>
>>> torch.linalg.eigh (m_lower)                # takes advantage of hermiticity of matrix
torch.return_types.linalg_eigh(
eigenvalues=tensor([-2.4809, -1.5880,  0.0173,  0.7716,  1.8363]),
eigenvectors=tensor([[-0.4259+0.0000j, -0.6590-0.0000j, -0.2628+0.0000j,  0.5261+0.0000j,
         -0.1965+0.0000j],
        [-0.2990-0.4251j, -0.2088-0.2968j,  0.0988+0.1405j, -0.3690-0.5247j,
          0.2281+0.3243j],
        [ 0.5983+0.0911j, -0.4138-0.0630j, -0.3100-0.0472j,  0.0325+0.0049j,
          0.5924+0.0902j],
        [-0.3422-0.2370j,  0.3956+0.2740j, -0.1670-0.1156j,  0.3280+0.2272j,
          0.5165+0.3577j],
        [ 0.0084+0.0943j, -0.0147-0.1647j,  0.0777+0.8694j,  0.0347+0.3877j,
          0.0200+0.2234j]]))
>>>
>>> m_tri.requires_grad = True
>>> loss = torch.linalg.eig (m_tri)[0][2]      # loss is a single eigenvalue
>>> loss.backward()                            # well defined for eigenvalues
>>> m_tri.grad
tensor([[ 0.0416+0.0051j, -0.0550+0.0633j, -0.1237+0.0026j, -0.1020+0.0785j,
         -0.0145+0.0424j],
        [-0.0409-0.0731j,  0.1669+0.0141j,  0.1500+0.1966j,  0.2469+0.0719j,
          0.0854-0.0266j],
        [-0.1189-0.0344j,  0.1900-0.1583j,  0.3616+0.0509j,  0.3340-0.1806j,
          0.0621-0.1166j],
        [-0.1094-0.0679j,  0.2347-0.1050j,  0.3477+0.1526j,  0.3852-0.0868j,
          0.0949-0.0992j],
        [ 0.0067-0.0443j,  0.0657+0.0606j, -0.0012+0.1321j,  0.0805+0.1113j,
          0.0447+0.0168j]])
>>>
>>> m_tri.grad = None
>>> loss = torch.linalg.eig (m_tri)[1][1, 2]   # loss is a single element of a single eigenvector
>>> loss.backward()                            # fails because of phase ambiguity in eigenvectors
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: linalg_eig_backward: The eigenvectors in the complex case are specified up to multiplication by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.
>>> m_tri.grad
>>> 

Note that torch.eig() was deprecated and removed several versions ago.

From recollection, torch.eig() was not as smart about recognizing when the
eigenvectors were real, so its backward pass was not always well defined.

Best.

K. Frank

Hi Frank, thank you for your answer! Can I somehow fix the phase to resolve the ambiguity and make the gradient well defined?

Hi Fabio!

Without knowing the details of your use case, I can’t really speculate what
an appropriate fix might be.

In general, the right approach is to call grad() only on quantities that are
independent of the (ambiguous) phase of the eigenvectors. Quoting from
the documentation for eigh():

For this reason, the loss function shall not depend on the phase of the eigenvectors, as this quantity is not well-defined.

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> n = 5
>>>
>>> d_main = torch.randn (n, dtype = torch.complex64)
>>> d_lower = torch.randn (n - 1, dtype = torch.complex64)
>>> m_lower = torch.diag (d_main) + torch.diag (d_lower, -1)
>>> m_lower.requires_grad = True
>>> eigs = torch.linalg.eigh (m_lower)
>>> eigs
torch.return_types.linalg_eigh(
eigenvalues=tensor([-2.4809, -1.5880,  0.0173,  0.7716,  1.8363],
       grad_fn=<LinalgEighBackward0>),
eigenvectors=tensor([[-0.4259+0.0000j, -0.6590-0.0000j, -0.2628+0.0000j,  0.5261+0.0000j,
         -0.1965+0.0000j],
        [-0.2990-0.4251j, -0.2088-0.2968j,  0.0988+0.1405j, -0.3690-0.5247j,
          0.2281+0.3243j],
        [ 0.5983+0.0911j, -0.4138-0.0630j, -0.3100-0.0472j,  0.0325+0.0049j,
          0.5924+0.0902j],
        [-0.3422-0.2370j,  0.3956+0.2740j, -0.1670-0.1156j,  0.3280+0.2272j,
          0.5165+0.3577j],
        [ 0.0084+0.0943j, -0.0147-0.1647j,  0.0777+0.8694j,  0.0347+0.3877j,
          0.0200+0.2234j]], grad_fn=<LinalgEighBackward0>))
>>> loss = eigs[1][1, 2].abs()                 # loss is a phase-independent function of the eigenvectors
>>> loss.backward()                            # backward() is well defined
>>> m_lower.grad
tensor([[-0.1912+5.1455e-10j,  0.0693-9.8580e-02j, -0.0954+1.4522e-02j,
         -0.0802+5.5560e-02j,  0.0248-2.7710e-01j],
        [ 0.0693+9.8580e-02j, -0.0758+0.0000e+00j,  0.0400+4.1699e-02j,
          0.0566+2.0807e-02j, -0.1446+8.3531e-02j],
        [-0.0954-1.4522e-02j,  0.0400-4.1699e-02j,  0.0420+4.6566e-10j,
         -0.0132+6.4294e-03j, -0.0251+1.0246e-01j],
        [-0.0802-5.5560e-02j,  0.0566-2.0807e-02j, -0.0132-6.4294e-03j,
         -0.0366+3.7253e-09j,  0.0309-3.7008e-02j],
        [ 0.0248+2.7710e-01j, -0.1446-8.3531e-02j, -0.0251-1.0246e-01j,
          0.0309+3.7008e-02j,  0.2617+7.4506e-09j]])

Best.

K. Frank


Thank you, I see.
I suspect I do not have a way around the problem.

My loss, shown below, is a function of the right (U) and left (V) eigenvectors, the eigenvalues, and a fixed TCorr function that I approximate:

class LossCorr(nn.Module):
    def __init__(self):
        super(LossCorr, self).__init__()
        self.Gamma = nn.Parameter(torch.rand((n), dtype=torch.float32), requires_grad=True)
        self.Omega = nn.Parameter(torch.rand((n), dtype=torch.float32), requires_grad=True)  # could be negative
        self.g = nn.Parameter(torch.rand((n-1), dtype=torch.float32), requires_grad=True)
        self.C = nn.Parameter(torch.rand((n, n), dtype=torch.complex64), requires_grad=True)  # C can be complex

    def forward(self):
        W = torch.zeros((n, n, n), dtype=torch.complex64)  # complex, since the slices assigned below are complex
        E, V, U = Diagonalization_algo(self.g, self.Gamma, self.Omega)  # calculate the eigenvectors and eigenvalues

        for i in range(n):
            UV = torch.mm(U[:, i].unsqueeze(dim=0).T, V[:, i].unsqueeze(dim=0))
            W[i, :, :] = (self.C.T @ UV) @ self.C.conj()

        Cor = torch.zeros((n, n), dtype=torch.complex64)

        loss = 0
        for t in range(Tmax):
            for i in range(n):
                Cor += W[i, :, :] * torch.exp(E[i]*t)

            loss += (1/Tmax) * torch.norm(Cor - TCorr(t), 2)

        return loss, self.Gamma, self.Omega, self.g

Hi Fabio!

I’m not sure that I see your problem. Based on what you have posted, it
looks like C is the only complex entity and it does not appear to participate
in Diagonalization_algo(). (As an aside, why not use something provided
by PyTorch, such as torch.linalg.eig(), instead of writing your own?)

The code you posted before suggests that you are diagonalizing a real,
symmetric, tri-diagonal matrix. You should therefore have real eigenvectors
(as well as eigenvalues), so you shouldn’t have any phase ambiguity, and
backpropagation should work fine.
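
For example, here is a minimal sketch along those lines (it assumes your Gamma,
Omega, and g enter the matrix with purely real entries, which may not match your
actual setup):

import torch

n = 4
Gamma = torch.rand (n, requires_grad = True)
Omega = torch.randn (n, requires_grad = True)
g = torch.rand (n - 1, requires_grad = True)

# real, symmetric, tri-diagonal matrix built from the parameters
diag = -torch.relu (Gamma) / 2 + Omega
off = -torch.relu (g)
m = torch.diag (diag) + torch.diag (off, 1) + torch.diag (off, -1)

evals, evecs = torch.linalg.eigh (m)    # real eigenvalues and eigenvectors
loss = evals[2] + evecs[1, 2]           # loss depends on an eigenvector element
loss.backward()                         # no phase ambiguity in the real case
print (Gamma.grad, Omega.grad, g.grad)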

Best.

K. Frank

Hi Frank!

Gamma, Omega, and g are real, but the matrix defined from them is complex:

def tridiag_eig(g, Gamma, Omega):
    n = Gamma.shape[0]
    diag = -torch.relu(Gamma)/2 + 1j*Omega    # impose positiveness constraints
    upper_diag = -1j * torch.relu(g)
    tridiag_mat = torch.diag(diag) + torch.diag(upper_diag, 1) + torch.diag(upper_diag, -1)