Hmm, let me ask a couple of questions to get more intuition about your problem.
By doing r_eigenvalues = torch.where(~torch.isreal(eigenvalues), 1.0, eigenvalues.real), are you assuming that the real eigenvalue is strictly less than 1.0? Because if it is not, the min call in the subsequent line will not pick an eigenvalue but the 1.0 you added with where.
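For example (a toy single matrix I made up just for illustration, with a complex-conjugate pair and one real eigenvalue at 2.0):

import torch

# 2x2 rotation block (eigenvalues ±i, i.e. complex) plus a real eigenvalue of 2.0
m = torch.tensor(
    [
        [0.0, -1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 2.0],
    ]
)
eigenvalues = torch.linalg.eigvals(m)  # roughly [i, -i, 2] (order may differ)
r_eigenvalues = torch.where(~torch.isreal(eigenvalues), 1.0, eigenvalues.real)
print(r_eigenvalues.min())  # tensor(1.) -- the sentinel, not the real eigenvalue 2.0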
And do you really care only about the real eigenvalues? Or would it also work for minimum real part of all eigenvalues or the minimum singular value?
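For reference, those two quantities would look something like this (a sketch, assuming y holds the batch of matrices as in the snippet below):

# minimum real part over all eigenvalues (complex ones included), per matrix in the batch
min_real_part = torch.linalg.eigvals(y).real.min(-1).values

# minimum singular value, per matrix in the batch (always real and non-negative)
min_singular_value = torch.linalg.svdvals(y).min(-1).values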
Well, if your answers are “yes” for both, I think you could do something like
import torch
B = 2  # batch size
N = 3  # matrix size (N x N)
leaf = torch.ones(1, requires_grad=True)  # just to create a graph
x = torch.tensor([-2.0, 1.0]) * leaf  # unconstrained parameter, one entry per matrix in the batch
y = (
    torch.tensor(
        [
            [
                [-0.2671, -0.0748, -0.1718],
                [-0.1510, 0.2960, -1.6154],
                [-0.0786, -1.6944, 0.9237],
            ],
            [
                [-0.7459, -1.0099, -0.2937],
                [-0.4803, -0.8911, -0.3678],
                [0.3982, -0.7645, 0.7407],
            ],
        ]
    )
    * leaf
)
eigenvalues = torch.linalg.eigvals(y)  # complex tensor of shape (B, N)
# replace complex eigenvalues with +inf so they can never win the min below
r_eigenvalues = torch.where(~torch.isreal(eigenvalues), torch.inf, eigenvalues.real)
min_eigenvalues, _ = r_eigenvalues.min(-1)  # smallest real eigenvalue of each matrix
softplus = torch.nn.Softplus(beta=5.0)
# reparametrize so that 1 + x_cond * min_eigenvalues = softplus(1 + x * min_eigenvalues) > 0
x_cond = (softplus(1 + x.mul(min_eigenvalues)) - 1) / min_eigenvalues
print(1 + x * min_eigenvalues) # [ 3.2149, -0.5520] <- unconstrained x violates condition
print(1 + x_cond * min_eigenvalues) # [3.2149, 0.0123] <- constrained x_cond doesn't
x_cond[1].backward()
print(leaf.grad) # got a non-zero grad
Then you use x_cond instead of x in whatever downstream operation you perform. Note that the constraint is not tightly active due to the smoothness of Softplus. You can make it tighter by increasing beta, but you may run into vanishing gradients. An alternative would be to use a LeakyReLU instead of Softplus, but then the constraint may be slightly violated.
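That swap would look something like this (a sketch reusing the names from the snippet above; the negative_slope value is just an illustration):

# LeakyReLU variant: the identity for positive pre-activations, and only slightly
# negative (i.e. a slight constraint violation) for negative ones, while keeping a
# non-vanishing gradient instead of saturating like Softplus with a large beta.
leaky_relu = torch.nn.LeakyReLU(negative_slope=0.01)
x_cond = (leaky_relu(1 + x.mul(min_eigenvalues)) - 1) / min_eigenvalues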