Backward Pass nan for Repeated Eigenvalues

Hi, I am seeing an issue on the backward pass when using torch.linalg.eigh on a Hermitian matrix with repeated eigenvalues. I was wondering if there is any way to obtain the eigenvector associated with the minimum eigenvalue without the gradients in the backward pass going to nan. I am performing this calculation as part of the loss function, and there are no learnable parameters after this point.

I have set up the following test case that illustrates the issue: the gradients coming back through the torch.linalg.eigh calculation contain nan. I am using PyTorch version 1.9.

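```python
import torch

# Shape (1, 4, 4): a batch holding a single 4x4 real symmetric (Hermitian) matrix
matrix = torch.tensor([[
            [1.2039, 0, 0.3690, 0],
            [0, -1.2039, 0, 0.3690],
            [0.3690, 0, -1.2039, 0],
            [0, 0.3690, 0, 1.2039],
    ]], requires_grad=True)

e, v = torch.linalg.eigh(matrix)
min_index = torch.argmin(e[0])        # index of the smallest eigenvalue
min_eigenvector = v[0, :, min_index]  # eigenvectors are the columns of v
min_state = min_eigenvector
min_sum = torch.sum(min_state) # Sum to obtain a scalar for .backward()
min_sum.backward()
assert not matrix.grad.isnan().any()  # check that no gradient entry is nan
```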

Hi Joe!

The short answer: No, there is not.

The issue is that in the presence of repeated (degenerate)
eigenvalues, “the eigenvector associated with” a degenerate
eigenvalue is not well defined.

An eigendecomposition algorithm (such as torch.linalg.eigh())
will give you a list of eigenvalues and associated eigenvectors. But
the two eigenvectors associated with two equal eigenvalues aren’t
well defined (only the two-dimensional subspace spanned by those
two eigenvectors is) – you can mix the two eigenvectors together
to get a different, but equally valid, pair of eigenvectors.
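A small sketch of what "only the subspace is well defined" means in practice, assuming float64 and an arbitrary mixing angle of 0.7 rad (the names V_low, V_mixed, P_low and P_mixed are just for this illustration): mixing the two degenerate eigenvectors changes the individual columns, but the projector onto the two-dimensional subspace they span stays the same.

```python
import math

import torch

# The matrix from the question (without the leading singleton dimension)
A = torch.tensor([
    [1.2039, 0.0, 0.3690, 0.0],
    [0.0, -1.2039, 0.0, 0.3690],
    [0.3690, 0.0, -1.2039, 0.0],
    [0.0, 0.3690, 0.0, 1.2039],
], dtype=torch.float64)

e, v = torch.linalg.eigh(A)   # eigenvalues in ascending order, eigenvectors as columns
V_low = v[:, :2]              # basis for the degenerate eigenvalue e[0] == e[1]

# Mix the two eigenvectors with an arbitrary 2x2 rotation
c, s = math.cos(0.7), math.sin(0.7)
R = torch.tensor([[c, -s], [s, c]], dtype=torch.float64)
V_mixed = V_low @ R

# The mixed columns are still eigenvectors with the same eigenvalue ...
print(torch.allclose(A @ V_mixed, e[0] * V_mixed, atol=1e-10))
# ... the individual columns have changed ...
print(torch.allclose(V_low, V_mixed))
# ... but the projector onto the subspace they span has not
P_low = V_low @ V_low.T
P_mixed = V_mixed @ V_mixed.T
print(torch.allclose(P_low, P_mixed, atol=1e-12))
```

Any quantity built from P_low (rather than from a single column of v) does not depend on which basis the eigendecomposition happens to return.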

As a result, backwarding through such an eigenvector will give you
ill-defined results that will depend on the particulars of the algorithm
used and that can be nan. (Backwarding through the eigenvalues,
however, is perfectly fine.)
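As a hedged illustration of the eigenvalue case, assuming float64 and using torch.linalg.eigvalsh (which returns only the eigenvalues): backwarding through the sum of the two equal smallest eigenvalues of the question's matrix gives a finite gradient, and that gradient is the projector onto the degenerate eigenspace, a basis-independent quantity. The variable names are just for this sketch.

```python
import torch

# Same matrix as in the question, in float64, tracking gradients
A = torch.tensor([
    [1.2039, 0.0, 0.3690, 0.0],
    [0.0, -1.2039, 0.0, 0.3690],
    [0.3690, 0.0, -1.2039, 0.0],
    [0.0, 0.3690, 0.0, 1.2039],
], dtype=torch.float64, requires_grad=True)

# A loss that depends only on eigenvalues: the sum of the two (equal) smallest ones
e = torch.linalg.eigvalsh(A)   # eigenvalues only, in ascending order
loss = e[0] + e[1]
loss.backward()

# Projector onto the degenerate eigenspace, computed outside of autograd
with torch.no_grad():
    _, v = torch.linalg.eigh(A)
    P = v[:, :2] @ v[:, :2].T

print('gradient is finite:', torch.isfinite(A.grad).all().item())
print('gradient equals projector:', torch.allclose(A.grad, P, atol=1e-10))
```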

Quoting from the documentation for torch.linalg.eigh():

Gradients computed using the eigenvectors tensor will only be finite when A has unique eigenvalues. Furthermore, if the distance between any two eigvalues is close to zero, the gradient will be numerically unstable, as it depends on the eigenvalues $\lambda_i$ through the computation of $\frac{1}{\min_{i \neq j} (\lambda_i - \lambda_j)}$.

(And, furthermore, because the eigenvectors are not well defined,
your example loss is also not well defined.)
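The $1 / \min_{i \neq j} (\lambda_i - \lambda_j)$ dependence from the quoted documentation can be seen directly with a small experiment. This is a sketch under stated assumptions: a 2x2 diagonal test matrix whose eigenvalue gap we choose ourselves, float64, and a hypothetical helper eigvec_grad written only for this illustration. The loss v[0, 0] * v[1, 0] is used because it does not change if eigh flips the sign of the eigenvector. As the gap shrinks, the gradient grows roughly like 1/gap, which is the instability the documentation warns about.

```python
import torch

def eigvec_grad(gap: float) -> torch.Tensor:
    """Hypothetical helper: gradient, w.r.t. A = [[1, 0], [0, 1 + gap]], of a
    sign-invariant function of the eigenvector of the smallest eigenvalue."""
    A = torch.tensor([[1.0, 0.0], [0.0, 1.0 + gap]],
                     dtype=torch.float64, requires_grad=True)
    _, v = torch.linalg.eigh(A)
    # v[0, 0] * v[1, 0] is unchanged if eigh flips the sign of column 0
    loss = v[0, 0] * v[1, 0]
    loss.backward()
    return A.grad

# Shrink the gap between the two eigenvalues and watch the gradient grow
for gap in [1e-1, 1e-3, 1e-5]:
    g = eigvec_grad(gap)
    print(f'gap = {gap:.0e}   max |grad| = {g.abs().max().item():.3g}')
```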

The following script illustrates the arbitrariness of the choice of
eigenvectors within the subspace defined by a set of degenerate
eigenvalues (and hence of your loss function):

```python
import torch
print (torch.__version__)

matrix = torch.tensor([   # no leading singleton dimension
            [1.2039, 0, 0.3690, 0],
            [0, -1.2039, 0, 0.3690],
            [0.3690, 0, -1.2039, 0],
            [0, 0.3690, 0, 1.2039]
])

# eigenvalues in ascending order, eigenvectors as the columns of v
e, v = torch.linalg.eigh (matrix)

print ('e =', e)

vChk = torch.allclose (v @ v.T, torch.eye (4), atol = 1.e-6)
print ('check that v is orthonormal:', vChk)

eChk = torch.allclose (v @ torch.diag (e) @ v.T, matrix, atol = 1.e-6)
print ('check that e, v is an eigendecomposition:', eChk)

sqh = 0.5**0.5

# rotate (mix) the first two eigenvectors, which share the degenerate eigenvalue
rot = torch.tensor ([
    [ sqh, -sqh,  0.0,  0.0 ],
    [ sqh,  sqh,  0.0,  0.0 ],
    [ 0.0,  0.0,  1.0,  0.0 ],
    [ 0.0,  0.0,  0.0,  1.0 ]
])

vr = v @ rot

vrChk = torch.allclose (vr @ vr.T, torch.eye (4), atol = 1.e-6)
print ('check that vr is orthonormal:', vrChk)

erChk = torch.allclose (vr @ torch.diag (e) @ vr.T, matrix, atol = 1.e-6)
print ('check that e, vr is an eigendecomposition:', erChk)

# the same "loss" evaluated with two equally valid sets of eigenvectors
loss_v = v[:, torch.argmin (e)].sum()
loss_vr = vr[:, torch.argmin (e)].sum()

print ('loss_v  =', loss_v)
print ('loss_vr =', loss_vr)
```

This script uses your matrix as its example, but with the leading
singleton dimension removed for simplicity.

Here is its output:

```
e = tensor([-1.2592, -1.2592,  1.2592,  1.2592])
check that v is orthonormal: True
check that e, v is an eigendecomposition: True
check that vr is orthonormal: True
check that e, vr is an eigendecomposition: True
loss_v  = tensor(0.8408)
loss_vr = tensor(1.1891)
```


K. Frank


Thank you for the help and the clear justification. Not the answer I was hoping for, but very useful nonetheless!