Hi Hnakao!
This can certainly be done.
The most natural way to represent a piecewise linear function is as a
linear-interpolation look-up table. (If each linear piece is non-decreasing,
the overall function will also be non-decreasing.)
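Concretely, if x falls in the bin between break-points b[i-1] and b[i], with function values v[i-1] and v[i] at those break-points, linear interpolation gives y = v[i-1] + (v[i] - v[i-1]) * (x - b[i-1]) / (b[i] - b[i-1]).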
One concrete approach is illustrated in the script below. The basic ideas
are as follows:
Use torch.bucketize to locate each input value within the look-up table.
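As a minimal illustration (using, for concreteness, the default break-points from the script below):

import torch
boundaries = torch.arange (11.)                # break-points 0., 1., ..., 10.
x = torch.tensor ([-1.7, 0.4, 2.5, 10.9])
print (torch.bucketize (x, boundaries))        # tensor([ 0,  1,  3, 11])

Note that values below the first break-point or above the last get the out-of-range indices 0 and 11, which is why the forward() method in the script below clamps the indices.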
Instead of storing the values of the function at the linear-interpolation
“break-point” boundaries, store the logs of the differences in the values
of the function from one break-point to the next. You then recover each
difference by exponentiating its log, thereby ensuring that the function
is non-decreasing.
(It might be easier to train the function differences (slopes) directly,
rather than their logs, but then you would have to impose a positivity
constraint on the differences during training, which can introduce its own
issues.)
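Here is a quick stand-alone sketch of the log-increment trick (separate from the script below):

import torch
log_diffs = torch.randn (5)                # unconstrained parameters -- any real values
values = log_diffs.exp().cumsum (0)        # increments are positive, so values are increasing
print ((values[1:] > values[:-1]).all())   # tensor(True)

Because exp() is always positive, the cumulative sums -- the function's values at the break-points -- are guaranteed to be increasing, no matter what values the trainable parameters take on.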
In my example, the log-differences (log-slopes) are trainable, but the
break-point boundaries are fixed (not trainable). They could be made
trainable, but that would introduce significant redundancy, and it’s not
clear that doing so would be beneficial.
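If you did want trainable break-points, one hypothetical way (not used in the script below) would be to apply the same log trick to the gaps between the break-points, so that the boundaries stay ascending by construction:

import torch
log_gaps = torch.nn.Parameter (torch.zeros (10))                        # trainable log-gaps between break-points
boundaries = torch.cat ([torch.zeros (1), log_gaps.exp().cumsum (0)])   # ascending by construction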
Here is an example script that learns an increasing section of the sin()
function:
import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

class TrainableInterpolationTable (torch.nn.Module):   # doesn't need to be a Module
    def __init__ (self, boundaries = None, values = None):
        super().__init__()
        if boundaries is None:
            boundaries = torch.arange (11.)
        if values is None:
            values = torch.zeros (10)
        assert boundaries.dim() == 1
        assert values.dim() == 1
        assert len (values) == len (boundaries) - 1
        self.nBin = len (values)                            # number of bins in the interpolation table
        self.boundaries = boundaries.clone()                # table-entry boundaries -- must be ascending
        self.values = torch.nn.Parameter (values.clone())   # logs of (positive) linear increments

    def forward (self, x):   # apply piecewise linear function element-wise to x
        dt = self.values.dtype
        inds = torch.bucketize (x, self.boundaries)   # index of the bin each element of x falls in
        inds_cl = torch.clamp (inds, 1, self.nBin)    # clamp out-of-range indices for safe indexing
        expval = self.values.exp()                    # positive increment of each bin
        # sum the increments of all bins that lie entirely below x
        y = (inds.unsqueeze (-1) > torch.arange (self.nBin) + 1).to (dtype = dt) @ expval
        # add the linearly-interpolated fraction of the increment of the bin containing x
        y = y + torch.logical_and (inds > 0, inds <= self.nBin).to (dtype = dt) * expval[inds_cl - 1] * (x - self.boundaries[inds_cl - 1]) / (self.boundaries[inds_cl] - self.boundaries[inds_cl - 1])
        return y
tbl_ident = TrainableInterpolationTable()
x = torch.arange (-1.7, 12.0, 2.1)
print ('x = ...')
print (x)
print ('tbl_ident (x) = ...')
print (tbl_ident (x))
bnd = torch.arange (0.0, 1.005, 0.01)
val = (bnd[1:]**2 - bnd[:-1]**2).log()
tbl_square = TrainableInterpolationTable (bnd, val)
x = torch.arange (-0.1, 1.15, 0.101)
print ('x = ...')
print (x)
print ('tbl_square (x) = ...')
print (tbl_square (x))
print ('tbl_square (x) - x**2 = ...')
print (tbl_square (x) - x**2)
x = torch.rand (1000)
sin_x = x.sin()
tbl_train = TrainableInterpolationTable (torch.arange (0.0, 1.005, 0.01), -3 * torch.ones (100))
opt = torch.optim.SGD (tbl_train.parameters(), lr = 0.2, momentum = 0.95)
print ('train tbl_train to learn sine function...')
for i in range (10001):
    opt.zero_grad()
    mse_loss = torch.nn.functional.mse_loss (tbl_train (x), sin_x)
    mse_loss.backward()
    opt.step()
    if not i % 1000:
        print ('i:', i, ' mse_loss:', mse_loss.item())
And here is its output:
1.10.2
x = ...
tensor([-1.7000, 0.4000, 2.5000, 4.6000, 6.7000, 8.8000, 10.9000])
tbl_ident (x) = ...
tensor([ 0.0000, 0.4000, 2.5000, 4.6000, 6.7000, 8.8000, 10.0000],
grad_fn=<AddBackward0>)
x = ...
tensor([-1.0000e-01, 1.0000e-03, 1.0200e-01, 2.0300e-01, 3.0400e-01,
4.0500e-01, 5.0600e-01, 6.0700e-01, 7.0800e-01, 8.0900e-01,
9.1000e-01, 1.0110e+00, 1.1120e+00])
tbl_square (x) = ...
tensor([0.0000e+00, 1.0000e-05, 1.0420e-02, 4.1230e-02, 9.2440e-02, 1.6405e-01,
2.5606e-01, 3.6847e-01, 5.0128e-01, 6.5449e-01, 8.2810e-01, 1.0000e+00,
1.0000e+00], grad_fn=<AddBackward0>)
tbl_square (x) - x**2 = ...
tensor([-1.0000e-02, 9.0000e-06, 1.6001e-05, 2.1003e-05, 2.3991e-05,
2.4989e-05, 2.3991e-05, 2.0981e-05, 1.5974e-05, 9.0003e-06,
5.9605e-08, -2.2121e-02, -2.3654e-01], grad_fn=<SubBackward0>)
train tbl_train to learn sine function...
i: 0 mse_loss: 5.558672904968262
i: 1000 mse_loss: 0.0014287722297012806
i: 2000 mse_loss: 8.150507346726954e-05
i: 3000 mse_loss: 4.557810825644992e-05
i: 4000 mse_loss: 3.0471503123408183e-05
i: 5000 mse_loss: 2.2024350982974283e-05
i: 6000 mse_loss: 1.6821579265524633e-05
i: 7000 mse_loss: 1.3377467439568136e-05
i: 8000 mse_loss: 1.096794949262403e-05
i: 9000 mse_loss: 9.207196853822097e-06
i: 10000 mse_loss: 7.875628398323897e-06
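If you want to see how closely the trained table tracks sin(), you could add a check along the following (hypothetical) lines to the end of the script:

xt = torch.linspace (0.0, 1.0, 11)
print (tbl_train (xt) - xt.sin())   # element-wise error of the trained table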
Best.
K. Frank