How to implement a non-decreasing function

I would like to learn a piecewise linear, non-decreasing function where both the input and the output are one-dimensional. This can be achieved by passing the input through a series of ReLU layers with constraints on the weights and biases so that the slopes are always non-negative.
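
For example, something like this sketch is what I have in mind (the number of kinks, their initial locations, and the softplus reparameterization are just placeholders, not a fixed recipe):

import torch

class MonotonePWL(torch.nn.Module):
    # f(x) = b + sum_i softplus(w_i) * relu(x - c_i) is piecewise linear and non-decreasing in x
    def __init__(self, n_kinks=8):
        super().__init__()
        self.b = torch.nn.Parameter(torch.zeros(1))
        self.w = torch.nn.Parameter(torch.zeros(n_kinks))               # raw weights, reparameterized to slopes
        self.c = torch.nn.Parameter(torch.linspace(0.0, 1.0, n_kinks))  # kink locations

    def forward(self, x):
        slopes = torch.nn.functional.softplus(self.w)                   # always non-negative
        return self.b + (slopes * torch.relu(x.unsqueeze(-1) - self.c)).sum(dim=-1)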

Can this be implemented in PyTorch? If not, are there alternative approaches?

Hi Hnakao!

This can certainly be done.

The most natural way to represent a piecewise linear function is as a
linear-interpolation look-up table. (If each linear piece is non-decreasing,
the overall function will also be non-decreasing.)
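
(For example, with break-points [0.0, 1.0, 2.0] and stored values [0.0, 0.5, 2.0], an input of 1.25 falls a quarter of the way through the second bin, so the interpolated output is 0.5 + 0.25 * (2.0 - 0.5) = 0.875.)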

One concrete approach is illustrated in the script below. The basic ideas
are as follows:

Use torch.bucketize to locate the input within the look-up table.
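
As a quick illustration of what torch.bucketize returns on its own (with the default right = False, 0 means the input lies below the first boundary and len (boundaries) means it lies above the last):

torch.bucketize (torch.tensor ([-0.5, 0.5, 2.0, 4.7]), torch.arange (5.))
# tensor([0, 1, 2, 5]) -- an input exactly on a boundary (2.0) gets the index of that boundary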

Instead of storing the values of the function at the linear-interpolation
“break-point” boundaries, store the log of the differences in the values
of the function from one break-point to the next. You then get the
difference by exponentiating the log, thereby ensuring that the function
is non-decreasing.
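
As a tiny illustration that this construction is automatically monotone, the break-point values recovered from any log-differences are non-decreasing, because each increment is an exp() and therefore positive:

log_diffs = torch.tensor ([-2.0, 0.3, -1.1])                          # arbitrary real numbers
values = torch.cat ([torch.zeros (1), log_diffs.exp().cumsum (0)])    # 0.0000, 0.1353, 1.4852, 1.8181 -- non-decreasing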

(It is possible that it would be easier to train the function differences (slopes)
directly rather than their logs, but then you would have to impose a positivity
constraint on the differences during training, which can introduce its own
issues.)
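
If you did go that route, one common way to impose the constraint would be to project after each optimizer step, along the lines of this fragment (assuming a hypothetical variant of the module whose values parameter stored the raw differences rather than their logs):

opt.step()
with torch.no_grad():
    tbl.values.clamp_ (min = 1.0e-6)   # keep the raw differences strictly positive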

In my example, the log-differences (log-slopes) are trainable, but the
break-point boundaries are fixed (not trainable). They could be made
trainable, but that would introduce significant redundancy, and it’s not
clear that doing so would be beneficial.
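
If you did want trainable break-points, the same log-increment trick would keep them in ascending order. A rough sketch (not used in the script below):

left_edge = torch.tensor ([0.0])                    # fixed left end of the table
log_gaps = torch.nn.Parameter (torch.zeros (10))    # trainable logs of the bin widths
boundaries = torch.cat ([left_edge, left_edge + log_gaps.exp().cumsum (0)])   # strictly ascending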

Here is an example script that learns an increasing section of the sin()
function:

import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

class TrainableInterpolationTable (torch.nn.Module):   # doesn't need to be a Module
    def __init__ (self, boundaries = None, values = None):
        super().__init__()
        
        if  boundaries is None:
            boundaries = torch.arange (11.)
        if  values is None:
            values = torch.zeros (10)
        
        assert  boundaries.dim() == 1
        assert  values.dim() == 1
        assert  len (values) == len (boundaries) - 1
        
        self.nBin = len (values)   # number of bins in the interpolation table
        self.boundaries = boundaries.clone()   # table-entry boundaries -- must be ascending
        self.values = torch.nn.Parameter (values.clone())   # logs of (positive) linear increments
    
    def forward (self, x):   # apply piecewise linear function element-wise to x
        dt = self.values.dtype
        inds = torch.bucketize (x, self.boundaries)   # index of the bin each element of x falls in
        inds_cl = torch.clamp (inds, 1, self.nBin)    # clamp out-of-range indices so they can index safely
        expval = self.values.exp()                    # positive increment of the function across each bin
        y = (inds.unsqueeze (-1) > torch.arange (self.nBin) + 1).to (dtype = dt) @ expval   # sum the increments of all bins that lie fully below x
        y = y + torch.logical_and (inds > 0, inds <= self.nBin).to (dtype = dt) * expval[inds_cl - 1] * (x - self.boundaries[inds_cl - 1]) / (self.boundaries[inds_cl] - self.boundaries[inds_cl - 1])   # linearly interpolate within x's own bin (no contribution outside the table)
        return y

tbl_ident = TrainableInterpolationTable()

x = torch.arange (-1.7, 12.0, 2.1)
print ('x = ...')
print (x)
print ('tbl_ident (x) = ...')
print (tbl_ident (x))

bnd = torch.arange (0.0, 1.005, 0.01)
val = (bnd[1:]**2 - bnd[:-1]**2).log()
tbl_square = TrainableInterpolationTable (bnd, val)

x = torch.arange (-0.1, 1.15, 0.101)
print ('x = ...')
print (x)
print ('tbl_square (x) = ...')
print (tbl_square (x))
print ('tbl_square (x) - x**2 = ...')
print (tbl_square (x) - x**2)

x = torch.rand (1000)
sin_x = x.sin()
tbl_train = TrainableInterpolationTable (torch.arange (0.0, 1.005, 0.01), -3 * torch.ones (100))
opt = torch.optim.SGD (tbl_train.parameters(), lr = 0.2, momentum = 0.95)
print ('train tbl_train to learn sine function...')
for  i in range (10001):
    opt.zero_grad()
    mse_loss = torch.nn.functional.mse_loss (tbl_train (x), sin_x)
    mse_loss.backward()
    opt.step()
    if  not i%1000:
        print ('i:', i, '  mse_loss:', mse_loss.item())

And here is its output:

1.10.2
x = ...
tensor([-1.7000,  0.4000,  2.5000,  4.6000,  6.7000,  8.8000, 10.9000])
tbl_ident (x) = ...
tensor([ 0.0000,  0.4000,  2.5000,  4.6000,  6.7000,  8.8000, 10.0000],
       grad_fn=<AddBackward0>)
x = ...
tensor([-1.0000e-01,  1.0000e-03,  1.0200e-01,  2.0300e-01,  3.0400e-01,
         4.0500e-01,  5.0600e-01,  6.0700e-01,  7.0800e-01,  8.0900e-01,
         9.1000e-01,  1.0110e+00,  1.1120e+00])
tbl_square (x) = ...
tensor([0.0000e+00, 1.0000e-05, 1.0420e-02, 4.1230e-02, 9.2440e-02, 1.6405e-01,
        2.5606e-01, 3.6847e-01, 5.0128e-01, 6.5449e-01, 8.2810e-01, 1.0000e+00,
        1.0000e+00], grad_fn=<AddBackward0>)
tbl_square (x) - x**2 = ...
tensor([-1.0000e-02,  9.0000e-06,  1.6001e-05,  2.1003e-05,  2.3991e-05,
         2.4989e-05,  2.3991e-05,  2.0981e-05,  1.5974e-05,  9.0003e-06,
         5.9605e-08, -2.2121e-02, -2.3654e-01], grad_fn=<SubBackward0>)
train tbl_train to learn sine function...
i: 0   mse_loss: 5.558672904968262
i: 1000   mse_loss: 0.0014287722297012806
i: 2000   mse_loss: 8.150507346726954e-05
i: 3000   mse_loss: 4.557810825644992e-05
i: 4000   mse_loss: 3.0471503123408183e-05
i: 5000   mse_loss: 2.2024350982974283e-05
i: 6000   mse_loss: 1.6821579265524633e-05
i: 7000   mse_loss: 1.3377467439568136e-05
i: 8000   mse_loss: 1.096794949262403e-05
i: 9000   mse_loss: 9.207196853822097e-06
i: 10000   mse_loss: 7.875628398323897e-06

Best.

K. Frank