Hi Zannas!
Yes, it is possible to approximate a step function with a relatively simple network.
I do believe that your model, together with your training data, has multiple local minima.
Note that you are training on a single, fixed batch of samples – your original linspace() inputs. My intuition is that this data is quite “spiky” or comb-like, and perhaps you get local minima based on how your linspace() values line up with the location of your step.
So the first thing I did was train with rand() values as the input training samples (with the target outputs scaled as in your example). This helps empirically, presumably because it smooths out the spikiness and because it adds “noise” to the training process, helping jostle the training out of local minima. (Each forward pass uses a different set of samples drawn from rand().)
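To make the change concrete, here is a minimal sketch of the two ways of producing training batches – the fixed linspace() batch versus a fresh rand() batch every iteration. It reuses the fGoal step function and the same scaling as the full script further down; the actual forward / backward / optimizer-step calls are elided:

import torch

fGoal = lambda x: torch.where (x > 100, 1000.0, 0)            # the step function being fit

# original approach: one fixed, evenly-spaced batch reused on every iteration
inputsFixed = (torch.linspace (0.0, 1000.0, 5001) / 1000.0).unsqueeze (1)
outputsFixed = fGoal (inputsFixed * 1000.0) / 1000.0

# modified approach: draw a fresh batch of random samples on every iteration
for i in range (3):                                           # a few iterations, just for illustration
    inp = torch.rand (100, 1)                                 # new "noisy" inputs each pass
    out = fGoal (inp * 1000.0) / 1000.0                       # matching targets, scaled the same way
    # ... forward pass, loss, backward(), and optimizer step would go here ...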
The next thing I did was “widen” your model, increasing the out_features of linear1 and the in_features of linear2 from 1 to 8. The idea is that the model now has more directions in which it can escape what would have been a local minimum had the model been narrower. (It is also possible that the wider model can better approximate your step function – this will in general be true – but I don’t think that this is what is going on here.)
(Think of a hillside with a number of little creeks flowing down it. If you build a number
of, say, east-west dams on that hillside, some of the creeks will turn into dammed-up
ponds – that is, local minima. Widening the model is, figuratively speaking, like removing
some of those dams.)
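In code, the widening amounts to nothing more than changing the hidden width of the two Linear layers. Here is a minimal sketch with the width pulled out as a parameter – the TwoLayerNet name and the hiddenWidth argument are mine, not part of the script below:

import torch

class TwoLayerNet (torch.nn.Module):
    # hiddenWidth = 1 reproduces the original narrow model; hiddenWidth = 8 gives the wider one
    def __init__ (self, hiddenWidth):
        super().__init__()
        self.linear1 = torch.nn.Linear (1, hiddenWidth)
        self.linear2 = torch.nn.Linear (hiddenWidth, 1)
        self.relu = torch.nn.LeakyReLU (negative_slope = 0.001)
    def forward (self, x):
        return self.relu (self.linear2 (self.relu (self.linear1 (x))))

narrow = TwoLayerNet (1)   # your original architecture
wide = TwoLayerNet (8)     # the widened version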
Last, I added a learning-rate scheduler, CyclicLR, to the training of the wider model. This increases and then decreases the learning rate as training progresses. Because we are using random (“noisy”) training samples, if the training has settled into a local minimum (or is even just stuck on a “plateau”), then a noisy sample, coupled with the (temporarily) large learning rate, has a better chance of jostling the training out of the minimum.
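Wiring up CyclicLR is straightforward; here is a minimal, self-contained sketch using the same scheduler settings as the script below (the placeholder model and dummy loss are just there to make it runnable – note that the scheduler is stepped once per training iteration, right after the optimizer step):

import torch

model = torch.nn.Linear (1, 1)                       # placeholder model, just to illustrate the wiring
opt = torch.optim.SGD (model.parameters(), lr = 0.1)
sched = torch.optim.lr_scheduler.CyclicLR (
    opt,
    base_lr = 0.001,            # learning rate at the bottom of each cycle
    max_lr = 0.5,               # learning rate at the peak of each cycle
    step_size_up = 100,         # iterations spent ramping the learning rate up
    step_size_down = 900,       # iterations spent ramping it back down
    mode = 'triangular2',       # the cycle amplitude is halved after each full cycle
    cycle_momentum = False,     # plain SGD has no momentum to cycle
)

for i in range (5):                                  # a few dummy iterations to show the call order
    opt.zero_grad()
    loss = model (torch.rand (10, 1)).pow (2).mean() # dummy loss
    loss.backward()
    opt.step()
    sched.step()                                     # advance the learning-rate schedule
    print (sched.get_last_lr())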
I’ve illustrated this with a sample script that implements these three ingredients – random training samples, the wider model, and CyclicLR. It trains your original model with noisy samples using plain-vanilla SGD, and it trains the wider model with noisy samples and SGD wrapped in CyclicLR. In each case the training is carried out ten times with different random initializations for the model parameters. The models are evaluated using your non-random linspace() samples.
For all ten runs, your original, narrower model gets stuck in what is presumably a local minimum, with a loss value at or near 0.0658. (From memory, I believe that I have seen lower loss values, so I don’t believe that this is the global minimum, but the particular random initializations used in this script don’t demonstrate that.)
On the other hand, the wider model with CyclicLR can easily be trained down to a pre-set lower loss limit of 0.005. When training, the wider model does sometimes show a tendency to slow down near certain loss values (including 0.0658), presumably indicating that it is passing across a minimum-like “plateau.” In two of the training runs, the wider model does get stuck at the 0.0658 loss value, presumably in that pesky local minimum.
Here is the script:
import torch

outfile = open ("step_function_fit.out", "w")

print (torch.__version__)
print (torch.__version__, file = outfile)

import math

_ = torch.manual_seed (2025)

device = 'cpu'
if torch.cuda.is_available(): device = 'cuda'

def doPrint (i):   # convenience function to filter printing
    vals = [0, 1, 2, 5, 10]
    sec1 = 10
    sec2 = 100000
    stp1 = [1, 3, 10]
    stp2 = range (1, 10)
    if i in vals: return True
    pow10 = 10**math.floor (math.log10 (i))
    if pow10 >= sec1 and pow10 <= sec2:
        if i / pow10 in stp1: return True
    if pow10 >= sec2:
        if i / pow10 in stp2: return True
    return False

nEpoch = 1000001   # maximum number of epochs to train
lossLim = 0.005    # train no further than this validation loss
nModels = 10       # number of models to train

inputs = torch.linspace (0.0, 1000.0, 5001) / 1000.0   # validation input with normalization
inputs = inputs.unsqueeze (1).to (device)              # make it a batch of size 5001

fGoal = lambda x: torch.where (x > 100, 1000.0, 0)     # the step function to fit

outputs = fGoal (inputs * 1000.0) / 1000.0             # validation output with normalization

class Model (torch.nn.Module):                         # original two-layer model
    def __init__ (self):
        super (Model, self).__init__()
        self.linear1 = torch.nn.Linear (1, 1)
        self.linear2 = torch.nn.Linear (1, 1)
        negativeSlopeRelu = 0.001
        self.relu = torch.nn.LeakyReLU (negative_slope = negativeSlopeRelu)
    def forward (self, x):
        y1 = self.linear1(x)
        y2 = self.relu(y1)
        y3 = self.linear2(y2)
        y4 = self.relu(y3)
        return y4

class ModelB (torch.nn.Module):                        # wider two-layer model
    def __init__ (self):
        super().__init__()
        self.linear1 = torch.nn.Linear (1, 8)
        self.linear2 = torch.nn.Linear (8, 1)
        negativeSlopeRelu = 0.001
        self.relu = torch.nn.LeakyReLU (negative_slope = negativeSlopeRelu)
    def forward (self, x):
        y1 = self.linear1(x)
        y2 = self.relu(y1)
        y3 = self.linear2(y2)
        y4 = self.relu(y3)
        return y4

lossFn = torch.nn.MSELoss()

models = [Model().to (device) for _ in range (nModels)]     # randomly-initialized original models
modelsB = [ModelB().to (device) for _ in range (nModels)]   # randomly-initialized wider models

# train multiple randomly-initialized original models with plain-vanilla SGD
print ('models training ...')
print ('models training ...', file = outfile)
for (iMod, model) in enumerate (models):
    opt = torch.optim.SGD (model.parameters(), lr = 0.1)
    lossInit = lossFn (model (inputs), outputs).detach()
    print ('iMod:', iMod, ' lossInit:', lossInit)
    print ('iMod:', iMod, ' lossInit:', lossInit, file = outfile)
    for i in range (nEpoch):
        inp = torch.rand (100, 1, device = device)   # "noisy" inputs for training
        out = fGoal (inp * 1000.0) / 1000.0          # ground truth outputs for inp
        pred = model (inp)
        lossTrain = lossFn (pred, out)
        with torch.no_grad(): lossVal = lossFn (model (inputs), outputs)   # validation loss on the linspace() samples
        if iMod == 0 and (doPrint (i) or i == nEpoch - 1 or lossVal <= lossLim):
            print ('i:', i, ', lossTrain:', lossTrain.detach(), ', lossVal:', lossVal)
            print ('i:', i, ', lossTrain:', lossTrain.detach(), ', lossVal:', lossVal, file = outfile)
        if lossVal <= lossLim: break                 # stop training once the validation loss is low enough
        opt.zero_grad()
        lossTrain.backward()
        opt.step()
    lossFinl = lossFn (model (inputs), outputs).detach()
    print ('iMod:', iMod, ' lossFinl:', lossFinl, ' epochs:', i)
    print ('iMod:', iMod, ' lossFinl:', lossFinl, ' epochs:', i, file = outfile)

# train multiple randomly-initialized wider models with random training samples and SGD with CyclicLR
print ('modelsB training ...')
print ('modelsB training ...', file = outfile)
for (iMod, model) in enumerate (modelsB):
    opt = torch.optim.SGD (model.parameters(), lr = 0.1)
    sched = torch.optim.lr_scheduler.CyclicLR (opt, base_lr = 0.001, max_lr = 0.5, step_size_up = 100, step_size_down = 900, mode = 'triangular2', cycle_momentum = False)
    lossInit = lossFn (model (inputs), outputs).detach()
    print ('iMod:', iMod, ' lossInit:', lossInit)
    print ('iMod:', iMod, ' lossInit:', lossInit, file = outfile)
    for i in range (nEpoch):
        inp = torch.rand (100, 1, device = device)   # "noisy" inputs for training
        out = fGoal (inp * 1000.0) / 1000.0          # ground truth outputs for inp
        pred = model (inp)
        lossTrain = lossFn (pred, out)
        with torch.no_grad(): lossVal = lossFn (model (inputs), outputs)   # validation loss on the linspace() samples
        if iMod == 0 and (doPrint (i) or i == nEpoch - 1 or lossVal <= lossLim):
            print ('i:', i, ', lossTrain:', lossTrain.detach(), ', lossVal:', lossVal)
            print ('i:', i, ', lossTrain:', lossTrain.detach(), ', lossVal:', lossVal, file = outfile)
        if lossVal <= lossLim: break                 # stop training once the validation loss is low enough
        opt.zero_grad()
        lossTrain.backward()
        opt.step()
        sched.step()                                 # advance the cyclic learning-rate schedule
    lossFinl = lossFn (model (inputs), outputs).detach()
    print ('iMod:', iMod, ' lossFinl:', lossFinl, ' epochs:', i)
    print ('iMod:', iMod, ' lossFinl:', lossFinl, ' epochs:', i, file = outfile)

outfile.close()
And here is its output:
2.6.0+cu126
models training ...
iMod: 0 lossInit: tensor(0.9010, device='cuda:0')
i: 0 , lossTrain: tensor(0.9312, device='cuda:0') , lossVal: tensor(0.9010, device='cuda:0')
i: 1 , lossTrain: tensor(0.9012, device='cuda:0') , lossVal: tensor(0.9010, device='cuda:0')
i: 2 , lossTrain: tensor(0.9112, device='cuda:0') , lossVal: tensor(0.9010, device='cuda:0')
i: 5 , lossTrain: tensor(0.9212, device='cuda:0') , lossVal: tensor(0.9010, device='cuda:0')
i: 10 , lossTrain: tensor(0.9212, device='cuda:0') , lossVal: tensor(0.9010, device='cuda:0')
i: 30 , lossTrain: tensor(0.9112, device='cuda:0') , lossVal: tensor(0.9010, device='cuda:0')
i: 100 , lossTrain: tensor(0.9412, device='cuda:0') , lossVal: tensor(0.9009, device='cuda:0')
i: 300 , lossTrain: tensor(0.9310, device='cuda:0') , lossVal: tensor(0.9008, device='cuda:0')
i: 1000 , lossTrain: tensor(0.8905, device='cuda:0') , lossVal: tensor(0.9003, device='cuda:0')
i: 3000 , lossTrain: tensor(0.0491, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 10000 , lossTrain: tensor(0.0838, device='cuda:0') , lossVal: tensor(0.0663, device='cuda:0')
i: 30000 , lossTrain: tensor(0.0508, device='cuda:0') , lossVal: tensor(0.0662, device='cuda:0')
i: 100000 , lossTrain: tensor(0.0793, device='cuda:0') , lossVal: tensor(0.0662, device='cuda:0')
i: 200000 , lossTrain: tensor(0.0708, device='cuda:0') , lossVal: tensor(0.0664, device='cuda:0')
i: 300000 , lossTrain: tensor(0.0718, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 400000 , lossTrain: tensor(0.0695, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 500000 , lossTrain: tensor(0.0623, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 600000 , lossTrain: tensor(0.0596, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 700000 , lossTrain: tensor(0.0623, device='cuda:0') , lossVal: tensor(0.0659, device='cuda:0')
i: 800000 , lossTrain: tensor(0.0460, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 900000 , lossTrain: tensor(0.0729, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 1000000 , lossTrain: tensor(0.0540, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
iMod: 0 lossFinl: tensor(0.0658, device='cuda:0') epochs: 1000000
iMod: 1 lossInit: tensor(0.1979, device='cuda:0')
iMod: 1 lossFinl: tensor(0.0658, device='cuda:0') epochs: 1000000
iMod: 2 lossInit: tensor(0.7612, device='cuda:0')
iMod: 2 lossFinl: tensor(0.0662, device='cuda:0') epochs: 1000000
iMod: 3 lossInit: tensor(0.9003, device='cuda:0')
iMod: 3 lossFinl: tensor(0.0658, device='cuda:0') epochs: 1000000
iMod: 4 lossInit: tensor(0.9020, device='cuda:0')
iMod: 4 lossFinl: tensor(0.0658, device='cuda:0') epochs: 1000000
iMod: 5 lossInit: tensor(0.1526, device='cuda:0')
iMod: 5 lossFinl: tensor(0.0659, device='cuda:0') epochs: 1000000
iMod: 6 lossInit: tensor(0.0940, device='cuda:0')
iMod: 6 lossFinl: tensor(0.0659, device='cuda:0') epochs: 1000000
iMod: 7 lossInit: tensor(0.8540, device='cuda:0')
iMod: 7 lossFinl: tensor(0.0659, device='cuda:0') epochs: 1000000
iMod: 8 lossInit: tensor(0.9013, device='cuda:0')
iMod: 8 lossFinl: tensor(0.0659, device='cuda:0') epochs: 1000000
iMod: 9 lossInit: tensor(0.4250, device='cuda:0')
iMod: 9 lossFinl: tensor(0.0658, device='cuda:0') epochs: 1000000
modelsB training ...
iMod: 0 lossInit: tensor(0.8999, device='cuda:0')
i: 0 , lossTrain: tensor(0.9101, device='cuda:0') , lossVal: tensor(0.8999, device='cuda:0')
i: 1 , lossTrain: tensor(0.9301, device='cuda:0') , lossVal: tensor(0.8999, device='cuda:0')
i: 2 , lossTrain: tensor(0.9001, device='cuda:0') , lossVal: tensor(0.8999, device='cuda:0')
i: 5 , lossTrain: tensor(0.9301, device='cuda:0') , lossVal: tensor(0.8999, device='cuda:0')
i: 10 , lossTrain: tensor(0.9201, device='cuda:0') , lossVal: tensor(0.8999, device='cuda:0')
i: 30 , lossTrain: tensor(0.9401, device='cuda:0') , lossVal: tensor(0.8999, device='cuda:0')
i: 100 , lossTrain: tensor(0.0981, device='cuda:0') , lossVal: tensor(0.0890, device='cuda:0')
i: 300 , lossTrain: tensor(0.0872, device='cuda:0') , lossVal: tensor(0.0668, device='cuda:0')
i: 1000 , lossTrain: tensor(0.0641, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 3000 , lossTrain: tensor(0.0580, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 10000 , lossTrain: tensor(0.0463, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 30000 , lossTrain: tensor(0.0554, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 100000 , lossTrain: tensor(0.0710, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 200000 , lossTrain: tensor(0.0662, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 300000 , lossTrain: tensor(0.0634, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 400000 , lossTrain: tensor(0.0883, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 500000 , lossTrain: tensor(0.0687, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 600000 , lossTrain: tensor(0.0532, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 700000 , lossTrain: tensor(0.0590, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 800000 , lossTrain: tensor(0.0639, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 900000 , lossTrain: tensor(0.0532, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
i: 1000000 , lossTrain: tensor(0.0811, device='cuda:0') , lossVal: tensor(0.0658, device='cuda:0')
iMod: 0 lossFinl: tensor(0.0658, device='cuda:0') epochs: 1000000
iMod: 1 lossInit: tensor(0.8238, device='cuda:0')
iMod: 1 lossFinl: tensor(0.0050, device='cuda:0') epochs: 393972
iMod: 2 lossInit: tensor(0.2943, device='cuda:0')
iMod: 2 lossFinl: tensor(0.0050, device='cuda:0') epochs: 440862
iMod: 3 lossInit: tensor(0.3555, device='cuda:0')
iMod: 3 lossFinl: tensor(0.0050, device='cuda:0') epochs: 403968
iMod: 4 lossInit: tensor(0.4646, device='cuda:0')
iMod: 4 lossFinl: tensor(0.0050, device='cuda:0') epochs: 558367
iMod: 5 lossInit: tensor(0.9003, device='cuda:0')
iMod: 5 lossFinl: tensor(0.0050, device='cuda:0') epochs: 648971
iMod: 6 lossInit: tensor(0.7196, device='cuda:0')
iMod: 6 lossFinl: tensor(0.0050, device='cuda:0') epochs: 412362
iMod: 7 lossInit: tensor(0.8966, device='cuda:0')
iMod: 7 lossFinl: tensor(0.0050, device='cuda:0') epochs: 430984
iMod: 8 lossInit: tensor(0.8642, device='cuda:0')
iMod: 8 lossFinl: tensor(0.0050, device='cuda:0') epochs: 383361
iMod: 9 lossInit: tensor(0.3980, device='cuda:0')
iMod: 9 lossFinl: tensor(0.0658, device='cuda:0') epochs: 1000000
For compactness, I only printed out the progress of the first of the ten narrower-model runs and the first of the ten wider-model runs. If you print out more of them, you can see that some random initializations (and presumably some random choices of training data) make it “easier” for the model to train well, yielding faster convergence and/or helping the model escape when it starts to get stuck in a minimum-like location.
Best.
K. Frank