Hi Fabio!
torch.linalg.eig() works for me (as does torch.linalg.eigh()):
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> n = 5
>>>
>>> d_main = torch.randn (n)
>>> d_lower = torch.randn (n - 1)
>>> m_lower = torch.diag (d_main) + torch.diag (d_lower, -1)
>>> m_tri = m_lower + torch.diag (d_lower, 1)
>>>
>>> torch.linalg.eig (m_tri) # general algorithm
torch.return_types.linalg_eig(
eigenvalues=tensor([ 1.4005+0.j, 0.5639+0.j, -1.2217+0.j, -1.1171+0.j, -0.7208+0.j]),
eigenvectors=tensor([[ 0.0027+0.j, -0.0882+0.j, -0.9948+0.j, 0.0512+0.j, 0.0054+0.j],
[ 0.0450+0.j, -0.9944+0.j, 0.0899+0.j, 0.0295+0.j, 0.0168+0.j],
[ 0.5250+0.j, -0.0086+0.j, -0.0403+0.j, -0.7924+0.j, -0.3077+0.j],
[ 0.8454+0.j, 0.0575+0.j, 0.0246+0.j, 0.5224+0.j, 0.0922+0.j],
[ 0.0875+0.j, 0.0098+0.j, -0.0114+0.j, -0.3092+0.j, 0.9468+0.j]]))
>>>
>>> torch.linalg.eigh (m_lower) # takes advantage of symmetry of matrix
torch.return_types.linalg_eigh(
eigenvalues=tensor([-1.2217, -1.1171, -0.7208, 0.5639, 1.4005]),
eigenvectors=tensor([[ 0.9948, 0.0512, 0.0054, 0.0882, -0.0027],
[-0.0899, 0.0295, 0.0168, 0.9944, -0.0450],
[ 0.0403, -0.7924, -0.3077, 0.0086, -0.5250],
[-0.0246, 0.5224, 0.0922, -0.0575, -0.8454],
[ 0.0114, -0.3092, 0.9468, -0.0098, -0.0875]]))
>>>
>>> m_tri.requires_grad = True
>>> loss = torch.linalg.eig (m_tri)[0][2] # loss is a single eigenvalue
>>> loss.backward()
>>> m_tri.grad
tensor([[ 9.8957e-01, -8.9380e-02, 4.0040e-02, -2.4454e-02, 1.1315e-02],
[-8.9380e-02, 8.0730e-03, -3.6165e-03, 2.2087e-03, -1.0220e-03],
[ 4.0041e-02, -3.6166e-03, 1.6201e-03, -9.8947e-04, 4.5784e-04],
[-2.4454e-02, 2.2088e-03, -9.8947e-04, 6.0430e-04, -2.7962e-04],
[ 1.1315e-02, -1.0220e-03, 4.5784e-04, -2.7962e-04, 1.2938e-04]])
>>>
>>> m_tri.grad = None
>>> loss = torch.linalg.eig (m_tri)[1][1, 2] # loss is a single element of a single eigenvector
>>> loss.backward()
>>> m_tri.grad
tensor([[ 0.0634, -0.0057, 0.0026, -0.0016, 0.0007],
[ 0.5604, -0.0506, 0.0227, -0.0138, 0.0064],
[-0.2186, 0.0197, -0.0088, 0.0054, -0.0025],
[ 0.1320, -0.0119, 0.0053, -0.0033, 0.0015],
[-0.0590, 0.0053, -0.0024, 0.0015, -0.0007]])
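As an aside, since m_tri here is real and symmetric, torch.linalg.eigh() also
supports a backward pass for eigenvector-dependent losses: its eigenvectors are
real and fixed up to an overall sign, so (away from degenerate eigenvalues)
such gradients are well defined. A minimal sketch along the lines of the
session above (illustrative, not output from an actual run):

import torch

torch.manual_seed (2023)
n = 5
d_main = torch.randn (n)
d_lower = torch.randn (n - 1)
m_tri = torch.diag (d_main) + torch.diag (d_lower, -1) + torch.diag (d_lower, 1)
m_tri.requires_grad = True

# eigh returns real eigenvalues and eigenvectors for a symmetric matrix,
# so a single eigenvector element is already a real-valued loss
loss = torch.linalg.eigh (m_tri)[1][1, 2]
loss.backward()
print (m_tri.grad)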
[Edit: I overlooked that you were asking about the complex case.]
The situation is similar for the complex case, but even in the (Hermitian)
tridiagonal case, you can now have complex eigenvectors, so the
eigenvector phase ambiguity makes the gradients of (phase-dependent)
functions of the eigenvectors ill-defined. (Gradients of eigenvalues
are still fine.) (One caveat about the example below: because d_main is
complex, m_tri is not exactly Hermitian, since a Hermitian matrix must have
a real main diagonal. That is why eig reports eigenvalues with nonzero
imaginary parts, while eigh, which reads only the lower triangle and treats
the matrix as Hermitian, reports real ones.)
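To see the ambiguity concretely: if v is an eigenvector of a with eigenvalue
lam, then e^{i phi} * v is an equally valid eigenvector for any phase phi, but
a phase-dependent quantity such as a single component of v changes with phi.
A minimal sketch (illustrative, not output from an actual run):

import torch

torch.manual_seed (2023)
a = torch.randn (4, 4, dtype = torch.complex64)
lams, vecs = torch.linalg.eig (a)

v = vecs[:, 0]
w = torch.exp (1j * torch.tensor (0.7)) * v   # same eigenvector, rotated by an arbitrary phase

# both satisfy the eigenvalue equation ...
print (torch.allclose (a @ v, lams[0] * v, atol = 1e-5))   # expected: True
print (torch.allclose (a @ w, lams[0] * w, atol = 1e-5))   # expected: True

# ... but a phase-dependent function of the eigenvector differs
print (v[1], w[1])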
Consider:
>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> _ = torch.manual_seed (2023)
>>>
>>> n = 5
>>>
>>> d_main = torch.randn (n, dtype = torch.complex64)
>>> d_lower = torch.randn (n - 1, dtype = torch.complex64)
>>> m_lower = torch.diag (d_main) + torch.diag (d_lower, -1)
>>> m_tri = m_lower + torch.diag (d_lower, 1).conj()
>>>
>>> torch.linalg.eig (m_tri) # general algorithm
torch.return_types.linalg_eig(
eigenvalues=tensor([-2.4665+0.3781j, -1.5699+0.4308j, 1.8091+0.4371j, 0.7244+0.4485j,
0.0592+0.0190j]),
eigenvectors=tensor([[-4.1827e-01+0.1134j, 6.5241e-01+0.0000j, -1.4480e-01+0.1424j,
-2.9093e-01+0.4490j, 9.6288e-02+0.1866j],
[-4.0111e-01-0.3376j, 2.1858e-01+0.2747j, 3.9774e-01+0.0802j,
6.3391e-01+0.0000j, 1.5024e-02-0.1544j],
[ 6.0069e-01+0.0000j, 4.2588e-01+0.0718j, 4.8217e-01-0.3556j,
-5.6944e-02+0.0620j, 6.8369e-04+0.2855j],
[-3.9204e-01-0.1196j, -4.4602e-01-0.1976j, 6.2306e-01+0.0000j,
-4.1192e-01+0.1344j, -7.1238e-03+0.1854j],
[ 2.1511e-02+0.0899j, 4.8561e-04+0.1631j, 1.8057e-01+0.1198j,
-3.3740e-01+0.0566j, 9.0328e-01+0.0000j]]))
>>>
>>> torch.linalg.eigh (m_lower) # takes advantage of hermiticity of matrix
torch.return_types.linalg_eigh(
eigenvalues=tensor([-2.4809, -1.5880, 0.0173, 0.7716, 1.8363]),
eigenvectors=tensor([[-0.4259+0.0000j, -0.6590-0.0000j, -0.2628+0.0000j, 0.5261+0.0000j,
-0.1965+0.0000j],
[-0.2990-0.4251j, -0.2088-0.2968j, 0.0988+0.1405j, -0.3690-0.5247j,
0.2281+0.3243j],
[ 0.5983+0.0911j, -0.4138-0.0630j, -0.3100-0.0472j, 0.0325+0.0049j,
0.5924+0.0902j],
[-0.3422-0.2370j, 0.3956+0.2740j, -0.1670-0.1156j, 0.3280+0.2272j,
0.5165+0.3577j],
[ 0.0084+0.0943j, -0.0147-0.1647j, 0.0777+0.8694j, 0.0347+0.3877j,
0.0200+0.2234j]]))
>>>
>>> m_tri.requires_grad = True
>>> loss = torch.linalg.eig (m_tri)[0][2] # loss is a single eigenvalue
>>> loss.backward() # well defined for eigenvalues
>>> m_tri.grad
tensor([[ 0.0416+0.0051j, -0.0550+0.0633j, -0.1237+0.0026j, -0.1020+0.0785j,
-0.0145+0.0424j],
[-0.0409-0.0731j, 0.1669+0.0141j, 0.1500+0.1966j, 0.2469+0.0719j,
0.0854-0.0266j],
[-0.1189-0.0344j, 0.1900-0.1583j, 0.3616+0.0509j, 0.3340-0.1806j,
0.0621-0.1166j],
[-0.1094-0.0679j, 0.2347-0.1050j, 0.3477+0.1526j, 0.3852-0.0868j,
0.0949-0.0992j],
[ 0.0067-0.0443j, 0.0657+0.0606j, -0.0012+0.1321j, 0.0805+0.1113j,
0.0447+0.0168j]])
>>>
>>> m_tri.grad = None
>>> loss = torch.linalg.eig (m_tri)[1][1, 2] # loss is a single element of a single eigenvector
>>> loss.backward() # fails because of phase ambiguity in eigenvectors
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<path_to_pytorch_install>\torch\_tensor.py", line 487, in backward
torch.autograd.backward(
File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: linalg_eig_backward: The eigenvectors in the complex case are specified up to multiplication by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.
>>> m_tri.grad
>>>
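If the quantity you care about is phase-invariant, the backward pass should go
through, because the loss then does not depend on the arbitrary phase. For
example, taking the absolute value of the eigenvector element as the loss
(a sketch, not output from an actual run):

import torch

torch.manual_seed (2023)
n = 5
d_main = torch.randn (n, dtype = torch.complex64)
d_lower = torch.randn (n - 1, dtype = torch.complex64)
m_tri = torch.diag (d_main) + torch.diag (d_lower, -1) + torch.diag (d_lower, 1).conj()
m_tri.requires_grad = True

# the absolute value of a single eigenvector element does not depend on
# the arbitrary phase, so its gradient is well defined
loss = torch.linalg.eig (m_tri)[1][1, 2].abs()
loss.backward()
print (m_tri.grad)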
Note that torch.eig() was deprecated and removed several versions ago.
From recollection, torch.eig() was not as smart about recognizing when the
eigenvectors were real, in which case the backward pass would have been
well defined.
Best.
K. Frank