Applying a constraint on tensors during optimization

I am optimizing three tensors, a, b, and c, of shape (64, 64, 1) each.

During the forward pass, I want their values to satisfy the following condition:

0 <= a + b + c <= 1

I don’t know how to update a, b, and c during optimization so that this condition holds.

Thanks in advance for any help and guidance.

Hi Utkarsh!

Using brute force to constrain a trainable parameter (for example, by clamping
its value directly) tends not to work well with gradient-descent optimization.
I generally find it preferable to smoothly map unconstrained parameters to
dependent variables that satisfy the desired constraints.

sigmoid() smoothly maps the unconstrained range (-inf, inf) to the
constrained range (0.0, 1.0).
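
A minimal sketch of why the smooth map helps, assuming that "brute force" means something like clamp(): clamping has zero gradient wherever it is active, so gradient descent gets no signal to move a clamped value back, while sigmoid() has a nonzero gradient everywhere. The input values are arbitrary illustrations.

```python
import torch

# A few unconstrained values, most of them outside [0.0, 1.0].
x = torch.tensor([-10.0, -3.0, 0.5, 3.0, 10.0], requires_grad=True)

# Brute-force alternative (assumed here to mean clamping) into [0.0, 1.0].
clamped = x.clamp(0.0, 1.0)
clamped.sum().backward()
clamp_grad = x.grad.clone()
x.grad = None  # reset before the second backward pass

# Smooth alternative: sigmoid maps (-inf, inf) into (0.0, 1.0).
smooth = torch.sigmoid(x)
smooth.sum().backward()
sigmoid_grad = x.grad.clone()

print(clamp_grad)    # zero wherever x was clamped
print(sigmoid_grad)  # positive everywhere: sigmoid(x) * (1 - sigmoid(x))
```

Note that the open interval (0.0, 1.0) holds mathematically; in floating point, sigmoid() of very large inputs can round to exactly 0.0 or 1.0.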

How best to use these ideas to constrain your a + b + c will depend on
your specific use case and what a, b, and c are supposed to mean.

Let me offer a rather elaborate way to achieve this, based on some tacit
assumptions about how you want a, b, and c to behave. The basic goals of
this particular scheme are to impose the joint constraint that a + b + c
lies in the range (0.0, 1.0), while otherwise leaving a, b, and c as
unconstrained as possible. We also want to treat a, b, and c on an equal
footing, rather than, say, leaving a and b fully unconstrained and then
constraining c so that the joint constraint is satisfied. These may or
may not be appropriate goals for your use case.
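
For contrast, here is a minimal sketch of the asymmetric alternative mentioned above, in which a and b are trained directly and c makes up the difference. The extra unconstrained parameter s0, which sets the sum, is a hypothetical name introduced only for this illustration.

```python
import torch

torch.manual_seed(0)

# a and b are trained directly; s0 (hypothetical) is an extra unconstrained
# parameter that determines the value of the sum.
a = torch.randn(64, 64, 1, requires_grad=True)
b = torch.randn(64, 64, 1, requires_grad=True)
s0 = torch.randn(64, 64, 1, requires_grad=True)

# c absorbs whatever is needed so that a + b + c = sigmoid(s0) lies in (0, 1).
c = torch.sigmoid(s0) - a - b

total = a + b + c
print(total.min(), total.max())
```

This satisfies the constraint, but c is fully determined by a, b, and s0, so c carries the whole constraint while a and b stay free. That is the asymmetry the scheme described here avoids.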

The idea is to start with three fully unconstrained trainable parameters,
which work well with gradient-descent optimization, and derive from them
three variables that satisfy the joint constraint. These derived variables
are then used in the rest of the forward pass.

To do so, we combine the three unconstrained parameters into a vector
(more precisely, stack() them into a dimension of a multidimensional
tensor), rotate that vector so that the sum of the parameters becomes a
single component of the rotated vector, map that sum to (0.0, 1.0),
and then rotate back.
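
Written out as equations (a sketch using the same matrix as the script, with its rows normalized to unit length; here x_0 = (a_0, b_0, c_0) is the stacked vector of unconstrained parameters, s = a_0 + b_0 + c_0, and σ is the sigmoid):

$$
\begin{aligned}
M &= \begin{pmatrix} 1/\sqrt{3} & 1/\sqrt{3} & 1/\sqrt{3} \\ 0 & 1/\sqrt{2} & -1/\sqrt{2} \\ -2/\sqrt{6} & 1/\sqrt{6} & 1/\sqrt{6} \end{pmatrix}, \qquad M M^\top = I \\
y &= M x_0, \qquad y_1 = \frac{a_0 + b_0 + c_0}{\sqrt{3}} = \frac{s}{\sqrt{3}} \\
y_1' &= \frac{\sigma(\sqrt{3}\, y_1)}{\sqrt{3}} = \frac{\sigma(s)}{\sqrt{3}}, \qquad y_2' = y_2, \quad y_3' = y_3 \\
x &= M^\top y' = x_0 + (y_1' - y_1)\, \frac{(1, 1, 1)^\top}{\sqrt{3}} = x_0 + \frac{\sigma(s) - s}{3}\, (1, 1, 1)^\top \\
a + b + c &= s + 3 \cdot \frac{\sigma(s) - s}{3} = \sigma(s) \in (0, 1)
\end{aligned}
$$

In particular, a - b = a_0 - b_0 and b - c = b_0 - c_0, so only the sum is constrained.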

This maps the unconstrained parameters to their constrained versions
smoothly and differentiably, and treats all three unconstrained parameters
on an equal footing.
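
Because the rotation is orthogonal and only the "a + b + c" component is changed, rotating back simply adds the same shift, (sigmoid(s) - s) / 3 with s = a0 + b0 + c0, to each parameter. A minimal sketch, assuming PyTorch, that compares this closed form with an out-of-place version of the rotation (torch.cat instead of an in-place copy_()); the function names constrain_sum and constrain_sum_rotated are hypothetical.

```python
import torch

def constrain_sum(a0, b0, c0):
    """Closed form: shift a0, b0, c0 equally along (1, 1, 1) so that
    a + b + c = sigmoid(a0 + b0 + c0), leaving a - b and b - c unchanged."""
    s = a0 + b0 + c0
    shift = (torch.sigmoid(s) - s) / 3
    return a0 + shift, b0 + shift, c0 + shift

def constrain_sum_rotated(a0, b0, c0):
    """Rotation version, written without in-place operations."""
    m = torch.tensor([
        [ 1.0,  1.0,  1.0],
        [ 0.0,  1.0, -1.0],
        [-2.0,  1.0,  1.0],
    ])
    m = m / m.norm(dim=1, keepdim=True)                # orthonormal rows
    y = torch.stack((a0, b0, c0), dim=-1) @ m.T        # rotate to "a + b + c" space
    y0 = torch.sigmoid(3**0.5 * y[..., :1]) / 3**0.5   # map the sum component
    y = torch.cat((y0, y[..., 1:]), dim=-1)            # reassemble the rotated vector
    return (y @ m).unbind(dim=-1)                      # rotate back

torch.manual_seed(2023)
a0, b0, c0 = (torch.randn(64, 64, 1) for _ in range(3))

a, b, c = constrain_sum(a0, b0, c0)
ar, br, cr = constrain_sum_rotated(a0, b0, c0)

print(torch.allclose(a, ar, atol=1e-5),
      torch.allclose(b, br, atol=1e-5),
      torch.allclose(c, cr, atol=1e-5))
print((a + b + c).min(), (a + b + c).max())
```

The two functions should agree up to floating-point rounding; the closed form avoids building the rotation matrix at all.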

Here is a script that implements this scheme:

```python
import torch
print (torch.__version__)

_ = torch.manual_seed (2023)

a0 = torch.randn (64, 64, 1)                            # unconstrained version of a to train
b0 = torch.randn (64, 64, 1)                            # unconstrained version of b to train
c0 = torch.randn (64, 64, 1)                            # unconstrained version of c to train

print ('(a0 + b0 + c0).min():', (a0 + b0 + c0).min())   # a0, b0, and c0 violate joint constraint
print ('(a0 + b0 + c0).max():', (a0 + b0 + c0).max())   # a0, b0, and c0 violate joint constraint

m = torch.tensor ([
    [ 1.0,  1.0,  1.0],                                 # will form  a + b + c
    [ 0.0,  1.0, -1.0],                                 # chosen to be orthogonal
    [-2.0,  1.0,  1.0]                                  # chosen to be orthogonal
])

m /= m.norm (dim = 1).unsqueeze (-1)                    # normalize rows so that m is orthonormal

print ('m @ m.T = ...')                                 # check that m is orthonormal
print (m @ m.T)

abc0 = torch.stack ((a0, b0, c0), dim = -1)

abc = abc0 @ m.T                                        # rotate to "a + b + c" space
abcsum = abc[:, :, :, 0]                                # view of "a + b + c" term
abcmap = (3**0.5 * abcsum).sigmoid() / (3**0.5)         # map  a + b + c  to range (0.0, 1.0)
abcsum.copy_ (abcmap)                                   # copy mapped value back into "a + b + c" view
abc = abc @ m                                           # rotate back to "abc" space

a, b, c = abc.unbind (dim = -1)                         # constrained a, b, and c tensors for forward pass

print ('(a + b + c).min():', (a + b + c).min())         # check that a, b, and c satisfy joint constraint
print ('(a + b + c).max():', (a + b + c).max())         # check that a, b, and c satisfy joint constraint

print ('(a + b).min():', (a + b).min())                 # a, b, and c are "otherwise" unconstrained
print ('(a + b).max():', (a + b).max())                 # a, b, and c are "otherwise" unconstrained
print ('(b + c).min():', (b + c).min())                 # a, b, and c are "otherwise" unconstrained
print ('(b + c).max():', (b + c).max())                 # a, b, and c are "otherwise" unconstrained
```

And here is its output:

```text
(a0 + b0 + c0).min(): tensor(-6.3463)
(a0 + b0 + c0).max(): tensor(6.2076)
m @ m.T = ...
tensor([[1.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 1.0000]])
(a + b + c).min(): tensor(0.0018)
(a + b + c).max(): tensor(0.9980)
(a + b).min(): tensor(-2.4540)
(a + b).max(): tensor(3.1374)
(b + c).min(): tensor(-3.1173)
(b + c).max(): tensor(3.1664)
```
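
To actually train with this scheme, a0, b0, and c0 need to be trainable parameters, which the script leaves out. A minimal sketch, assuming torch.nn and torch.optim.Adam; the module name JointlyConstrainedABC, the toy squared-error loss, and its target values are hypothetical, chosen only so the example runs. It computes the same map as the script, written as an equal shift of (sigmoid(s) - s) / 3 added to each parameter, where s = a0 + b0 + c0 (this follows because the rotation is orthogonal and only the sum component changes).

```python
import torch
import torch.nn as nn

class JointlyConstrainedABC(nn.Module):
    """Hypothetical module: holds unconstrained a0, b0, c0 and returns
    a, b, c with a + b + c = sigmoid(a0 + b0 + c0)."""

    def __init__(self, shape=(64, 64, 1)):
        super().__init__()
        self.a0 = nn.Parameter(torch.randn(shape))
        self.b0 = nn.Parameter(torch.randn(shape))
        self.c0 = nn.Parameter(torch.randn(shape))

    def forward(self):
        s = self.a0 + self.b0 + self.c0
        shift = (torch.sigmoid(s) - s) / 3   # equal shift along (1, 1, 1)
        return self.a0 + shift, self.b0 + shift, self.c0 + shift

torch.manual_seed(2023)
model = JointlyConstrainedABC()
opt = torch.optim.Adam(model.parameters(), lr=0.05)

losses = []
for step in range(200):
    opt.zero_grad()
    a, b, c = model()
    # Hypothetical toy loss: pull a, b, c toward 0.5, 0.3, -0.1.
    loss = ((a - 0.5) ** 2 + (b - 0.3) ** 2 + (c + 0.1) ** 2).mean()
    loss.backward()
    opt.step()
    losses.append(loss.item())

with torch.no_grad():
    a, b, c = model()
    total = a + b + c
    print(losses[0], losses[-1], total.min().item(), total.max().item())
```

The optimizer only ever updates the unconstrained a0, b0, and c0, so the joint constraint holds after every step without any clamping.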


K. Frank

If you can turn your constraint into an equality constraint, you can use Lagrange multipliers (Lagrange multiplier - Wikipedia). One strategy for turning an inequality constraint into an equality constraint is to introduce slack variables (Slack variable - Wikipedia). For an example of how slack variables are used, see how the simplex method handles them: Introduction to the simplex method.
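
As a sketch of how these pieces could fit together for this problem (assuming one common variant that writes the slack variables as squares, u² and v², so that no sign constraint remains; f stands for whatever loss you are minimizing, and the constraint applies element-wise, with one pair of multipliers per element):

$$
\begin{aligned}
s &= a + b + c \\
0 \le s \le 1 \quad &\Longleftrightarrow \quad s - u^2 = 0 \ \text{and} \ 1 - s - v^2 = 0 \ \text{for some real } u, v \\
\mathcal{L}(a, b, c, u, v, \lambda_1, \lambda_2) &= f(a, b, c) + \lambda_1 (s - u^2) + \lambda_2 (1 - s - v^2) \\
\frac{\partial \mathcal{L}}{\partial a} &= \frac{\partial f}{\partial a} + \lambda_1 - \lambda_2 \quad (\text{likewise for } b \text{ and } c) \\
\frac{\partial \mathcal{L}}{\partial u} &= -2 \lambda_1 u, \qquad \frac{\partial \mathcal{L}}{\partial v} = -2 \lambda_2 v
\end{aligned}
$$

Setting the last two derivatives to zero gives λ₁u = 0 and λ₂v = 0, so a multiplier can be nonzero only when its constraint is active (s = 0 or s = 1). The solution is a stationary point of 𝓛, generally a saddle point rather than a minimum, so plain gradient descent on 𝓛 will not find it; a common approach is to descend in (a, b, c, u, v) and ascend in (λ₁, λ₂), in which case the constraint holds only approximately until the iteration converges.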