Clamping to satisfy an inequality without messing up grads

Hello!

I have two tensors, x of shape (B,) and y of shape (B, N, N). Each matrix in y is theoretically guaranteed to have at least one real eigenvalue. I need x to always satisfy 1 + x.mul(minimum_real_eigenvalue(y)) > 0.

Basically, I want to compute something like x_cond = x.clamp(...) or x_cond = torch.min(...) and use x_cond in downstream tasks, without messing up the gradients.
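Assuming x > 0 (which the OP confirms later in the thread) and writing λ_b for the minimum real eigenvalue of y[b], the condition only restricts x where λ_b is negative, and there it becomes a simple upper bound:

$$
1 + x_b\,\lambda_b > 0 \iff
\begin{cases}
\text{always true}, & \lambda_b \ge 0,\\[2pt]
x_b < -\dfrac{1}{\lambda_b}, & \lambda_b < 0.
\end{cases}
$$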

Any ideas? Thank you!

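Indeed, clamp and min will make your gradient vanish: wherever the bound is active, the output no longer depends on x, so x receives zero gradient there.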

Maybe Softplus (with a beta>5) would be a better alternative?
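To see the difference, here is a minimal sketch on two made-up values of z = 1 + x * λ_min, one violating the constraint and one satisfying it. A hard clamp passes zero gradient to the violating entry, while Softplus passes sigmoid(beta * z), which is small but non-zero. The clamp's lower bound of 1e-3 is an arbitrary illustrative choice.

```python
import torch
import torch.nn.functional as F

# Made-up values of z = 1 + x * lambda_min: the first violates z > 0
z = torch.tensor([-0.5, 2.0], requires_grad=True)

# Hard clamp (lower bound 1e-3 chosen arbitrarily): the violating entry is
# replaced by a constant, so it receives zero gradient
clamped = torch.clamp(z, min=1e-3)
(g_clamp,) = torch.autograd.grad(clamped.sum(), z)

# Softplus: smooth, strictly positive stand-in for max(z, 0);
# its derivative sigmoid(beta * z) is mathematically never zero
soft = F.softplus(z, beta=5.0)
(g_soft,) = torch.autograd.grad(soft.sum(), z)

print(g_clamp)  # tensor([0., 1.])
print(g_soft)
```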

What I am currently doing is something like this:

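```python
import torch

# x: (B,) learnable Parameter, y: (B, N, N)
eigenvalues = torch.linalg.eigvals(y)
# Replace non-real eigenvalues with a placeholder of 1.0
r_eigenvalues = torch.where(~torch.isreal(eigenvalues), 1.0, eigenvalues.real)
min_eigenvalues = r_eigenvalues.min(-1).values
# Keep x where 1 + x * lambda_min > 0, otherwise fall back to just below -1 / lambda_min
x_cond = torch.where(x.mul(min_eigenvalues).add(1) > 0,
                     x,
                     min_eigenvalues.mul(-1).pow(-1) - 1e-5)
```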

So I am not literally using min or clamp. My problem is that when I use this in place of the unbounded x (which is a learnable Parameter), I get exploding gradients, even though the values of x look reasonable when I inspect them (on the order of 1e-1 or 1e-2).

How would you introduce Softplus here?

Hmm, let me ask a couple of questions to get a better intuition for your problem.

By doing r_eigenvalues = torch.where(~torch.isreal(eigenvalues), 1.0, eigenvalues.real), are you assuming that the real eigenvalue is strictly less than 1.0? If it is not, the min call on the next line will return the 1.0 placeholder you inserted with where rather than an actual eigenvalue.

And do you really care only about the real eigenvalues? Or would the minimum real part of all eigenvalues, or the minimum singular value, also work?
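A small illustration of the placeholder issue, using hand-picked eigenvalues (one real eigenvalue, 2.0, plus a complex-conjugate pair) rather than an actual decomposition. With 1.0 as the placeholder, min returns 1.0 instead of the real eigenvalue; with +inf as the placeholder, the placeholder can never win the min.

```python
import torch

# Hypothetical eigenvalues of one 3x3 matrix: one real eigenvalue (2.0)
# and a complex-conjugate pair
eigenvalues = torch.tensor([2.0 + 0j, 0.5 + 1j, 0.5 - 1j])
is_real = torch.isreal(eigenvalues)

# Placeholder 1.0 for non-real entries: min() returns the placeholder
with_one = torch.where(~is_real, 1.0, eigenvalues.real)
# Placeholder +inf: min() can only return a genuine real eigenvalue
with_inf = torch.where(~is_real, torch.inf, eigenvalues.real)

print(with_one.min(-1).values)  # tensor(1.)
print(with_inf.min(-1).values)  # tensor(2.)
```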
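If the minimum real part over all eigenvalues is acceptable, note that it can never exceed the minimum real eigenvalue, because the real eigenvalues are among the candidates. For x > 0, enforcing 1 + x * (minimum real part) > 0 therefore also enforces the original condition, at the price of being more conservative. A sketch comparing the two quantities on the same two 3x3 matrices used in the example code further down in this reply:

```python
import torch

# Same values as the example matrices in this reply (no grad tracking here)
y = torch.tensor(
    [
        [
            [-0.2671, -0.0748, -0.1718],
            [-0.1510, 0.2960, -1.6154],
            [-0.0786, -1.6944, 0.9237],
        ],
        [
            [-0.7459, -1.0099, -0.2937],
            [-0.4803, -0.8911, -0.3678],
            [0.3982, -0.7645, 0.7407],
        ],
    ]
)

eigenvalues = torch.linalg.eigvals(y)  # complex, shape (B, N)

# Option A: minimum real eigenvalue (non-real ones masked out with +inf)
min_real_eig = torch.where(
    torch.isreal(eigenvalues), eigenvalues.real, torch.inf
).min(-1).values

# Option B: minimum real part over all eigenvalues (no masking needed)
min_real_part = eigenvalues.real.min(-1).values

print(min_real_eig, min_real_part)
```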

Well, if the answer to both is “yes”, I think you could do something like this:

```python
import torch


B = 2
N = 3

leaf = torch.ones(1, requires_grad=True)  # just to create a graph

x = torch.tensor([-2.0, 1.0]) * leaf
y = (
    torch.tensor(
        [
            [
                [-0.2671, -0.0748, -0.1718],
                [-0.1510, 0.2960, -1.6154],
                [-0.0786, -1.6944, 0.9237],
            ],
            [
                [-0.7459, -1.0099, -0.2937],
                [-0.4803, -0.8911, -0.3678],
                [0.3982, -0.7645, 0.7407],
            ],
        ]
    )
    * leaf
)


eigenvalues = torch.linalg.eigvals(y)
# Non-real eigenvalues become +inf, so min() can only return a real eigenvalue
r_eigenvalues = torch.where(~torch.isreal(eigenvalues), torch.inf, eigenvalues.real)
min_eigenvalues, _ = r_eigenvalues.min(-1)

softplus = torch.nn.Softplus(beta=5.0)

# By construction (assuming min_eigenvalues != 0),
# 1 + x_cond * min_eigenvalues ~= softplus(1 + x * min_eigenvalues) > 0
x_cond = (softplus(1 + x.mul(min_eigenvalues)) - 1) / min_eigenvalues

print(1 + x * min_eigenvalues)  # [ 3.2149, -0.5520] <- unconstrained x violates condition
print(1 + x_cond * min_eigenvalues)  # [3.2149, 0.0123] <- constrained x_cond doesn't

x_cond[1].backward()
print(leaf.grad)  # got a non-zero grad
```

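Then you use x_cond instead of x in whatever downstream operation you perform. Note that, because Softplus is smooth, the constraint is never tightly active: softplus(z) is strictly positive and slightly larger than z near the boundary, so x_cond stays strictly inside the feasible region and differs a little from x even where 1 + xλ is small but positive. You can make it tighter by increasing beta, but you may run into vanishing gradients. An alternative would be to use a LeakyReLU instead of Softplus, which keeps a non-zero gradient everywhere, but then the constraint may be slightly violated.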
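The trade-off can be seen on a few made-up values of z = 1 + x * λ_min. This sketch uses torch.nn.functional.softplus and torch.nn.functional.leaky_relu; the beta values and the 0.01 negative slope are arbitrary illustrative choices. As beta grows, softplus(z) hugs max(z, 0) more closely, but the gradient sigmoid(beta * z) at the violating entry collapses toward zero. LeakyReLU keeps a fixed slope there, but its output stays negative, which is the slight constraint violation mentioned above.

```python
import torch
import torch.nn.functional as F

# Made-up values of z = 1 + x * lambda_min:
# a violation, a value just above the boundary, and a comfortably feasible one
z = torch.tensor([-0.5, 0.05, 3.0])

results = {}
for beta in (1.0, 5.0, 50.0):
    zz = z.clone().requires_grad_(True)
    out = F.softplus(zz, beta=beta)
    (grad,) = torch.autograd.grad(out.sum(), zz)
    results[beta] = (out.detach(), grad)
    # out > 0 everywhere; the gradient is sigmoid(beta * z) while
    # beta * z <= 20 (the default threshold) and exactly 1 above it
    print(f"beta={beta}: out={out.detach()}, grad={grad}")

# LeakyReLU: non-zero slope everywhere, but a negative input stays negative,
# so 1 + x_cond * lambda_min would end up slightly below zero
leaky = F.leaky_relu(z, negative_slope=0.01)
print(leaky)
```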

Thanks, lucas! Using softplus as a smooth proxy for max is a neat trick.

My choice of 1 in the first torch.where rested on the fact that x > 0 always (which, in hindsight, I should’ve specified in the OP!), so any real eigenvalue > 1 passes the test anyway, and it doesn't matter if min returns the placeholder instead. I also wanted to keep the subsequent 1 / min_eigenvalues from producing a nan or inf (as it could have with a placeholder of 0) that might mess up the grads, even though those values would sit in the branch that torch.where discards. Paranoia, basically. :joy:
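The general worry about the discarded branch is justified: torch.where sends a zero gradient into the branch it did not select, but that branch's own backward still runs, and zero times an infinite local derivative gives NaN. The sketch reproduces this with a made-up lam tensor containing an exact zero, and shows the common "double where" workaround (a general pattern, not something from this thread): replace the unsafe inputs before the risky operation, so the discarded branch only ever sees finite values.

```python
import torch

# Made-up values: lam == 0 is the dangerous entry
lam = torch.tensor([0.0, -2.0], requires_grad=True)

# Naive: -1 / lam is evaluated for every entry, including lam == 0,
# so its backward computes 0 * inf = nan for the discarded entry
naive = torch.where(lam < 0, -1.0 / lam, torch.zeros_like(lam))
(g_naive,) = torch.autograd.grad(naive.sum(), lam)
print(g_naive)  # nan for the entry with lam == 0

# "Double where": feed the division a safe dummy value (-1) wherever
# its result will be discarded anyway
safe_lam = torch.where(lam < 0, lam, -torch.ones_like(lam))
guarded = torch.where(lam < 0, -1.0 / safe_lam, torch.zeros_like(lam))
(g_guarded,) = torch.autograd.grad(guarded.sum(), lam)
print(g_guarded)  # finite: 0 for lam == 0, 1 / lam**2 = 0.25 for lam == -2
```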