Hi Utkarsh!

Using brute force to constrain a trainable parameter tends not to work nicely with gradient-descent optimization. I generally find it preferable to smoothly map unconstrained parameters to dependent variables that satisfy the desired constraints.

`sigmoid()` smoothly maps the unconstrained range `(-inf, inf)` to the constrained range `(0.0, 1.0)`.
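As a minimal standalone illustration of that mapping (the values here are arbitrary, chosen just to show the endpoints of the range):

```
import torch

p = torch.tensor([-10.0, 0.0, 10.0])  # unconstrained values
q = p.sigmoid()                       # mapped into (0.0, 1.0)
print(q)                              # values near 0.0, exactly 0.5, and near 1.0
```

However extreme the unconstrained value, the mapped value stays strictly inside `(0.0, 1.0)`, and the map is differentiable everywhere.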

How best to use these ideas to constrain your `a + b + c` will depend on your specific use case and what `a`, `b`, and `c` are supposed to mean.

Let me offer a rather elaborate way to achieve this, based on some tacit assumptions about how you want `a`, `b`, and `c` to behave. The basic goals of this particular scheme are to impose the joint constraint that `a + b + c` lies in the range `(0.0, 1.0)`, while otherwise leaving `a`, `b`, and `c` "as unconstrained as possible." Also, we want to treat `a`, `b`, and `c` on an equal footing, rather than, say, leaving `a` and `b` fully unconstrained and then constraining `c` so that the joint constraint is satisfied. These may or may not be goals that are appropriate for your use case.

The idea is to start with three fully unconstrained trainable parameters, so that they work well with gradient-descent optimization, and to derive from them three variables that satisfy the joint constraint, which are then used for the rest of the forward pass.

To do so, we combine the three unconstrained parameters into a vector (more precisely, `stack()` them into a dimension of a multidimensional tensor), rotate that vector so that the sum of the parameters becomes a single component of the rotated vector, map that sum to `(0.0, 1.0)`, and then rotate back.

This smoothly and differentiably maps the unconstrained parameters to their constrained versions and treats the three unconstrained parameters all on the same footing.
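To see concretely why the rotation isolates the sum, here is a small check using the same orthonormal matrix as the script below (the input values are arbitrary stand-ins for `a`, `b`, and `c`). After normalizing its rows, the first row of `m` is `(1, 1, 1) / sqrt(3)`, so the first component of the rotated vector is `(a + b + c) / sqrt(3)`:

```
import torch

m = torch.tensor([
    [ 1.0,  1.0,  1.0],   # will form a + b + c
    [ 0.0,  1.0, -1.0],   # chosen to be orthogonal
    [-2.0,  1.0,  1.0],   # chosen to be orthogonal
])
m /= m.norm(dim=1).unsqueeze(-1)  # normalize rows so m is orthonormal

v = torch.tensor([0.3, -1.2, 2.5])  # stand-in values for a, b, and c
w = m @ v                           # rotate
print(w[0] * 3**0.5)                # recovers a + b + c (here 1.6, up to rounding)
```

This factor of `sqrt(3)` is why the script multiplies by `3**0.5` before the `sigmoid()` and divides by it afterward.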

Here is a script that implements this scheme:

```
import torch
print (torch.__version__)
_ = torch.manual_seed (2023)
a0 = torch.randn (64, 64, 1) # unconstrained version of a to train
b0 = torch.randn (64, 64, 1) # unconstrained version of b to train
c0 = torch.randn (64, 64, 1) # unconstrained version of c to train
print ('(a0 + b0 + c0).min():', (a0 + b0 + c0).min()) # a0, b0, and c0 violate joint constraint
print ('(a0 + b0 + c0).max():', (a0 + b0 + c0).max()) # a0, b0, and c0 violate joint constraint
m = torch.tensor ([
[ 1.0, 1.0, 1.0], # will form a + b + c
[ 0.0, 1.0, -1.0], # chosen to be orthogonal
[-2.0, 1.0, 1.0] # chosen to be orthogonal
])
m /= m.norm (dim = 1).unsqueeze (-1) # make m orthonormal
print ('m @ m.T = ...') # check that m is orthonormal
print (m @ m.T)
abc0 = torch.stack ((a0, b0, c0), dim = -1)
abc = abc0 @ m.T # rotate to "a + b + c" space
abcsum = abc[:, :, :, 0] # view of "a + b + c" term
abcmap = (3**0.5 * abcsum).sigmoid() / (3**0.5) # map a + b + c to range (0.0, 1.0)
abcsum.copy_ (abcmap) # copy mapped value back into "a + b + c" view
abc = abc @ m # rotate back to "abc" space
a, b, c = abc.unbind (dim = -1) # constrained a, b, and c tensors for forward pass
print ('(a + b + c).min():', (a + b + c).min()) # check that a, b, and c satisfy joint constraint
print ('(a + b + c).max():', (a + b + c).max()) # check that a, b, and c satisfy joint constraint
print ('(a + b).min():', (a + b).min()) # a, b, and c are "otherwise" unconstrained
print ('(a + b).max():', (a + b).max()) # a, b, and c are "otherwise" unconstrained
print ('(b + c).min():', (b + c).min()) # a, b, and c are "otherwise" unconstrained
print ('(b + c).max():', (b + c).max()) # a, b, and c are "otherwise" unconstrained
```

And here is its output:

```
2.1.0
(a0 + b0 + c0).min(): tensor(-6.3463)
(a0 + b0 + c0).max(): tensor(6.2076)
m @ m.T = ...
tensor([[1.0000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000],
[0.0000, 0.0000, 1.0000]])
(a + b + c).min(): tensor(0.0018)
(a + b + c).max(): tensor(0.9980)
(a + b).min(): tensor(-2.4540)
(a + b).max(): tensor(3.1374)
(b + c).min(): tensor(-3.1173)
(b + c).max(): tensor(3.1664)
```
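As a final sanity check that the scheme plays well with autograd, here is a minimal sketch (scalar parameters and a made-up loss, both purely for illustration) that applies the same rotate, map, and rotate-back mapping out of place and verifies that gradients flow through the constrained `a`, `b`, and `c` back to the unconstrained parameters:

```
import torch

m = torch.tensor([
    [ 1.0,  1.0,  1.0],
    [ 0.0,  1.0, -1.0],
    [-2.0,  1.0,  1.0],
])
m /= m.norm(dim=1).unsqueeze(-1)  # make m orthonormal

p = torch.randn(3, requires_grad=True)       # unconstrained trainable parameters
abc = m @ p                                  # rotate to "a + b + c" space
abc = torch.cat((
    (3**0.5 * abc[:1]).sigmoid() / 3**0.5,   # map the sum component
    abc[1:],                                 # leave the other components alone
))
a, b, c = m.T @ abc                          # rotate back; constrained a, b, c

loss = (a + b + c - 0.5)**2                  # made-up loss for illustration
loss.backward()
print(p.grad)                                # gradients reach the unconstrained parameters
```

This version uses `torch.cat()` instead of the in-place `copy_()` purely for clarity; either way, the whole mapping is differentiable, so an optimizer can update the unconstrained parameters directly.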

Best.

K. Frank