Hi Sourya!

This is not correct. Backpropagation / gradient computation can,
*in appropriate cases*, occur through `torch.linalg.eig()` where the
eigenvectors / eigenvalues have non-zero imaginary parts. (The code you
linked to that checks for an imaginary part tests an intermediate result,
not the eigenvectors of your `Ktilde`.)

The short answer is: don't try to compute gradients of quantities that
depend on the phases of your eigenvectors. (Pytorch does support computing
gradients of phase-independent quantities.)
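
For instance, the squared modulus of an eigenvector component does not
change when that eigenvector is multiplied by e^(i phi), so a loss built
from such moduli has a well-defined gradient that pytorch will happily
compute. Here is a minimal sketch of that idea (the matrix `M` is just a
made-up example):

```
import torch

M = torch.randn (4, 4, requires_grad = True)   # made-up real matrix; its eigenpairs are generally complex
evals, evecs = torch.linalg.eig (M)

# |v_00|^2 = v_00 * conj (v_00) does not change if the first eigenvector
# is multiplied by e^(i phi), so this loss (and its gradient) is well-defined
loss = (evecs * evecs.conj())[0, 0].real
loss.backward()                                # no RuntimeError, even in version 1.12
print (M.grad)
```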

There is not a *good* way around this restriction. At issue is that the
warning in the documentation and the `RuntimeError` you cited in your
original post are fully legitimate. If you get that `RuntimeError`, you
are almost certainly doing something that doesn't make sense.

The problem – as mentioned in the documentation warning – is that the
phases of the eigenvectors are not mathematically uniquely defined. So
a loss function that depends on those phases, as well as the gradient of
such a loss function, is also not uniquely defined.
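
To see this arbitrariness concretely: if `v` is an eigenvector of a matrix
with some eigenvalue, then so is `v` multiplied by e^(i phi) for any angle
phi. Here is a short sketch (again with a made-up matrix and an arbitrarily
chosen phase):

```
import torch

M = torch.randn (3, 3)                         # made-up real matrix
evals, evecs = torch.linalg.eig (M)
Mc = M.to (dtype = torch.complex64)

v = evecs[:, 0]                                # the first eigenvector, as eig() happened to return it
w = v * torch.exp (1j * torch.tensor (0.7))    # same eigenvector, shifted by an arbitrary phase

print (torch.allclose (Mc @ v, evals[0] * v))  # True
print (torch.allclose (Mc @ w, evals[0] * w))  # also True -- w is an equally valid eigenvector
```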

Pytorch could, hypothetically, impose an ad hoc set of rules that uniquely
determines those phases, but that would just give you ad hoc uniqueness
that would still be mathematically arbitrary, so you'd be sweeping your
issue under the rug, rather than actually fixing it. (Also, if you look at
my example, below, you can see that setting up those ad hoc rules could be
tricky, because there are multiple, mathematically equivalent paths to
computing the same eigenvectors.)

The following example illustrates what is going on. As a device to generate
eigenvectors with differing phases, I apply an orthogonal transformation
to the original real square matrix, compute its eigenvectors, and then
rotate those eigenvectors back to the basis of the original matrix using
the same orthogonal transformation. You can understand this as a second,
fully mathematically legitimate algorithm for computing the eigenvectors.

To make clear that pytorch version 1.12 *does* permit you to backpropagate
through `eig()` *provided* you are computing the gradient of a
phase-independent quantity, I post the example as run in both versions
1.10 and 1.12. (These two example runs are almost identical duplicates
of one another, so there's no need to compare them line by line. The
only difference is at the end where the 1.12 version flags the illegitimate
backpropagation.)

Here is the 1.10 version:

```
>>> import torch
>>> print (torch.__version__)
1.10.2
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> tA = torch.randn (3, 3)
>>> tB = tA.clone()
>>> tA.requires_grad = True
>>> tB.requires_grad = True
>>>
>>> valA, vecA = torch.linalg.eig (tA)
>>>
>>> orth = torch.linalg.svd (torch.randn (3, 3))[0] # orthogonal matrix to rotate tB
>>> tBr = orth @ tB @ orth.T
>>> valBr, vecBr = torch.linalg.eig (tBr) # eigenvectors of rotated tB
>>>
>>> tAc = tA.to (dtype = torch.complex64)
>>> orthc = orth.to (dtype = torch.complex64)
>>>
>>> vecB = orthc.T @ vecBr # rotate eigenvectors back to basis of tB
>>> torch.allclose (tAc @ vecB, valA * vecB) # check that vecB is indeed a set of eigenvectors of tA
True
>>>
>>> vecB / vecA # each eigenvector is changed by a complex phase
tensor([[ 0.9991+0.0414j, 0.9991-0.0414j, -1.0000-0.0000j],
        [ 0.9991+0.0414j, 0.9991-0.0414j, -1.0000-0.0000j],
        [ 0.9991+0.0414j, 0.9991-0.0414j, -1.0000+0.0000j]],
       grad_fn=<DivBackward0>)
>>>
>>> lossGoodA = (vecA * vecA.conj())[0, 0]
>>> lossGoodB = (vecB * vecB.conj())[0, 0]
>>>
>>> lossGoodA
tensor(0.0332+0.j, grad_fn=<SelectBackward0>)
>>> lossGoodB
tensor(0.0332+0.j, grad_fn=<SelectBackward0>)
>>> torch.allclose (lossGoodA, lossGoodB) # phase-independent loss is the same for both sets of eigenvectors
True
>>>
>>> lossGoodA.backward() # both versions 1.10 and 1.12 permit legitimate backward pass
>>> tA.grad
tensor([[-0.0096, 0.4561, -0.3340],
        [ 0.0084, -0.0616, 0.0662],
        [-0.0009, -0.1106, 0.0712]])
>>>
>>> lossBadA = (vecA * vecA).real[0, 0]
>>> lossBadB = (vecB * vecB).real[0, 0]
>>>
>>> lossBadA
tensor(0.0198, grad_fn=<SelectBackward0>)
>>> lossBadB
tensor(0.0176, grad_fn=<SelectBackward0>)
>>> torch.allclose (lossBadA, lossBadB) # this loss is not phase-independent
False
>>>
>>> lossBadB.backward() # version 1.12 flags this backward pass as ill-defined
>>> tB.grad # this gradient is phase-dependent (in version 1.10 and not computed in version 1.12)
tensor([[ 0.0844, 0.2666, 0.0732],
        [-0.0095, -0.0822, 0.0267],
        [-0.0213, -0.0918, -0.0022]])
```

And here is the 1.12 version:

```
>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> tA = torch.randn (3, 3)
>>> tB = tA.clone()
>>> tA.requires_grad = True
>>> tB.requires_grad = True
>>>
>>> valA, vecA = torch.linalg.eig (tA)
>>>
>>> orth = torch.linalg.svd (torch.randn (3, 3))[0] # orthogonal matrix to rotate tB
>>> tBr = orth @ tB @ orth.T
>>> valBr, vecBr = torch.linalg.eig (tBr) # eigenvectors of rotated tB
>>>
>>> tAc = tA.to (dtype = torch.complex64)
>>> orthc = orth.to (dtype = torch.complex64)
>>>
>>> vecB = orthc.T @ vecBr # rotate eigenvectors back to basis of tB
>>> torch.allclose (tAc @ vecB, valA * vecB) # check that vecB is indeed a set of eigenvectors of tA
True
>>>
>>> vecB / vecA # each eigenvector is changed by a complex phase
tensor([[ 0.9991+0.0414j, 0.9991-0.0414j, -1.0000-0.0000j],
        [ 0.9991+0.0414j, 0.9991-0.0414j, -1.0000-0.0000j],
        [ 0.9991+0.0414j, 0.9991-0.0414j, -1.0000+0.0000j]],
       grad_fn=<DivBackward0>)
>>>
>>> lossGoodA = (vecA * vecA.conj())[0, 0]
>>> lossGoodB = (vecB * vecB.conj())[0, 0]
>>>
>>> lossGoodA
tensor(0.0332+0.j, grad_fn=<SelectBackward0>)
>>> lossGoodB
tensor(0.0332+0.j, grad_fn=<SelectBackward0>)
>>> torch.allclose (lossGoodA, lossGoodB) # phase-independent loss is the same for both sets of eigenvectors
True
>>>
>>> lossGoodA.backward() # both versions 1.10 and 1.12 permit legitimate backward pass
>>> tA.grad
tensor([[-0.0096, 0.4561, -0.3340],
        [ 0.0084, -0.0616, 0.0662],
        [-0.0009, -0.1106, 0.0712]])
>>>
>>> lossBadA = (vecA * vecA).real[0, 0]
>>> lossBadB = (vecB * vecB).real[0, 0]
>>>
>>> lossBadA
tensor(0.0198, grad_fn=<SelectBackward0>)
>>> lossBadB
tensor(0.0176, grad_fn=<SelectBackward0>)
>>> torch.allclose (lossBadA, lossBadB) # this loss is not phase-independent
False
>>>
>>> lossBadB.backward() # version 1.12 flags this backward pass as ill-defined
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<path_to_pytorch_install>\torch\_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: linalg_eig_backward: The eigenvectors in the complex case are specified up to multiplication by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.
>>> tB.grad # this gradient is phase-dependent (in version 1.10 and not computed in version 1.12)
>>>
```

Just to be clear, the `RuntimeError` raised in version >= 1.11 isn't
*breaking* your code – it's helpfully warning you that your code is
already broken.
Best.

K. Frank