My code needs a very specific random seed not to get stuck at the same loss

I made a very simple 3-layer fully-connected network for binary classification (using the NASA C-MAPSS dataset to classify healthy and faulty turbofan engines). The input is a vector of length 26 and the output is a sigmoid activation. The task is pretty easy (a basic logistic regression model gives me 100% test accuracy). I'm porting this code from Keras (where everything worked as expected), but when executing the PyTorch code the loss doesn't change. Running the exact same code several times, most of the time I get that non-changing loss; sometimes it works and converges within a couple of epochs to 100% test accuracy as expected. I tried using the Keras settings (learning rate, Adam parameters, weight initialization), yet the problem persisted.

After setting the random seed (without the default weights initialization this time) I get the same result on every run (obviously!): the loss still doesn't change during training, but its stuck value differs from seed to seed. I had to try different seeds until I found one for which the model trains as expected and runs correctly every time (I set the seed to 527; other values may have worked, but that's the only one I found).
What can be the cause of this behavior?

Here is my code and training process (the data-processing code is a bit long and I'm sure it's not the problem):

import torch
import torch.nn as nn
import torch.optim as optim

class CMAPSSBinaryClassifier(nn.Module):
    def __init__(self):
        super(CMAPSSBinaryClassifier, self).__init__()
        self.fc1 = nn.Linear(26, 16)
        self.fc2 = nn.Linear(16, 4)
        self.fc3 = nn.Linear(4, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        output = torch.sigmoid(self.fc3(x))
        return output

data_path = "/home/abdeljalil/Workspace/Datasets/CMAPSS/"
data_FD01 = CMAPSSDataset(data_path, fd_number=1)
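#setting the seed mentioned above (527 is the value that happened to work);
#it must run before the model is constructed to affect its initial weights
#torch.manual_seed(527)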
model_FD01 = CMAPSSBinaryClassifier()

#tried Xavier weights initialization scheme
#model_FD01.apply(init_weights)

loader_train, loader_test = data_FD01.construct_binary_classification_data(good_faulty_threshould=30, batch_size=64)
epochs = 100
#tried Adam with a wide range of learning rates
optimizer = optim.Adam(model_FD01.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-7)
#tried SGD with different learning rates too
#optimizer = optim.SGD(model_FD01.parameters(), lr=0.001, momentum=0.9)

model_FD01.train()
criterion = nn.BCELoss()

for epoch in range(epochs):
    correct = 0
    for batch_id, (data, target) in enumerate(loader_train):
        optimizer.zero_grad()
        output = model_FD01(data)
        output = output.view_as(target)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        output = torch.round(output)
        correct += output.eq(target).sum().item()

    print('Train epoch: {}\t Accuracy ({:.0f}%)\tLoss: {:.6f}'.format(epoch, 100. * correct/len(loader_train.dataset), loss.item()))

Here is a sample output (first 6 epochs) when the model doesn't converge (about 90% of the time):

Train epoch: 0   Accuracy (50%) Loss: 13.815515
Train epoch: 1   Accuracy (50%) Loss: 13.815515
Train epoch: 2   Accuracy (50%) Loss: 13.815515
Train epoch: 3   Accuracy (50%) Loss: 13.815515
Train epoch: 4   Accuracy (50%) Loss: 13.815515
Train epoch: 5   Accuracy (50%) Loss: 13.815515

Sometimes the loss changes but gets stuck at around 0.69:

Train epoch: 0   Accuracy (50%) Loss: 0.704716
Train epoch: 1   Accuracy (50%) Loss: 0.701211
Train epoch: 2   Accuracy (50%) Loss: 0.698781
Train epoch: 3   Accuracy (50%) Loss: 0.697099
Train epoch: 4   Accuracy (50%) Loss: 0.695932
Train epoch: 5   Accuracy (50%) Loss: 0.695122

And the rare few times it actually works:

Train epoch: 0   Accuracy (56%)         Loss: 0.516986
Train epoch: 1   Accuracy (90%)         Loss: 0.318052
Train epoch: 2   Accuracy (100%)        Loss: 0.203251
Train epoch: 3   Accuracy (100%)        Loss: 0.136395
Train epoch: 4   Accuracy (100%)        Loss: 0.096920
Train epoch: 5   Accuracy (100%)        Loss: 0.072752

Hello Abdeljalil!

This particular issue does not sound like a bug in your code,
although I suppose that would be possible.

More likely is that your training procedure is not robust for this
particular network / dataset.

(To clarify, is your loss literally unchanging, or is it just “stuck,” and
not making progress? Have you tried letting your training run for a
long time to see if it can get off of its plateau?)

First, I would switch to SGD instead of using the Adam optimizer.
(I don’t have much experience with Adam, but the lore seems to
be that when it works, it can converge faster, but that it gets stuck
more easily.)

Second, I would try different learning rates, both significantly larger
and significantly smaller. Depending on the nature of your plateau
or “getting stuck,” a larger or smaller learning rate can help address
the issue. If this works, but is suboptimal (e.g., too slow), you could
then try using a learning-rate scheduler.
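If you get to that point, a minimal scheduler sketch might look
like this (step_size and gamma are placeholder values, not a
recommendation):

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(epochs):
    for data, target in loader_train:
        ...  # the usual forward / backward / optimizer.step()
    scheduler.step()  # multiply the lr by gamma every step_size epochs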

As for the difference between pytorch and what I assume is
keras / tensorflow: speaking from memory, they scale their
initial random weights differently.

However, your training procedure should be robust with respect to
the general scale of your initial weights (and certainly your random
number seed). So you could try changing the scale of your initial
weights, but even if that were to work, I would say it’s the wrong fix.

In short, I would suggest using SGD and trying various learning
rates. If something appears to work, make sure it works with most
(all?) initial random number seeds. If you get something working,
you can then back up and try to fine tune it to get it to run faster.

Good luck.

K. Frank


I just tried both SGD and Adam with learning rates from 0.000001 to 0.1 and the problem persisted. I was going to post the output in this reply but preferred editing the original post. Please check it again.

Hello Abdeljalil!

It’s probably not the issue, but please also try SGD with no
momentum (momentum = 0.0).
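That is, reusing the optimizer line from your post:

optimizer = optim.SGD(model_FD01.parameters(), lr=0.001, momentum=0.0)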

Some comments based on your edited post:

First, get rid of the sigmoid() at the end of your model. Thus:

        output = self.fc3(x)
        return output

And use BCEWithLogitsLoss (rather than BCELoss) as your
criterion. This is in any event best practice, and it’s possible
that in your case you are getting stuck where sigmoid() saturates
to either 0.0 or 1.0.
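Concretely, the only other change is the criterion; a quick
sketch using the names from your post:

criterion = nn.BCEWithLogitsLoss()

output = model_FD01(data).view_as(target)  # raw logits now, not probabilities
loss = criterion(output, target)           # applies sigmoid + BCE internally, stably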

(Note: when you convert to BCEWithLogitsLoss, you should
threshold the output of your model at 0.0 rather than 0.5 when
calculating your accuracy. Rounding a number that is between 0
and 1 is the same as thresholding it at 0.5, in the language I'm
using.)
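In code, the accuracy bookkeeping would become something like:

pred = (output > 0.0).float()           # logit > 0 <=> probability > 0.5
correct += pred.eq(target).sum().item()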

Because your accuracy gets “stuck” at 50%, I speculate that
your dataset is exactly balanced – that is, it contains the same
number of “healthy” and “faulty” samples – and that your model
(when stuck) always predicts (close to) 0 or (close to) 1.

What do representative predictions from your model look like?
Is your data in fact balanced? Do you shuffle your data? Why
do you get exactly (to two decimal digits) 50% accuracy?
(Notice that your accuracy doesn’t move away from 50% until
your loss has fallen from about 14 to about 0.5!)

To clarify, what happens when you run your keras / tensorflow
model multiple times with different random seeds? Is it robust?

As an aside: Your output = output.view_as(target)
looks fishy. I’m guessing that it’s unnecessary (that is, that
your output and target already have the same shape).

What are the shapes of data, output (before view_as()),
and target?

Good luck.

K. Frank

I just got rid of torch.sigmoid() and used BCEWithLogitsLoss instead of BCELoss, and it works as expected: there is no stuck-loss problem anymore and the model converges within a couple of epochs. Thank you so much for the advice.

From PyTorch documentation:

This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.

But I’m not sure why this really solves my problem.

I really don’t know what you mean here exactly, but the model output now is just some integers. Example output:

[ 7., -3., -5., -5., -5.,  4., -5., -5., -5.,  4., -5., 13., -5., -5., -5.,
-4.,  6.,  4.,  3.,  6., -3.,  3., -5.,  3., 11., -4., -4., 11., -5., -6.,
-5.,  4.,  6.,  7.,  5.,  7., -5.,  4.,  9., -5.,  8.,  4., 5.,  9., -4., 
1., 10.,  4., 14., -4., -4., -4.,  5., -5.,  6., 15., 3., -6., -5., -5.,
-5., -5., -4.,  3.]

So I just did the following:

output = torch.round(torch.sigmoid(output))

Yes, that's true. It's because I constructed the dataset myself (I extracted the first and last 30 samples from each turbofan engine simulation in the C-MAPSS dataset as negative and positive samples, respectively), so my dataset is exactly balanced. But I never thought that would be the cause of the problem.

Yes, running the same code in Keras always yields almost the same results; it's pretty robust.

It's just the way my code generates the data: the model output shape is [64, 1] while the target shape is [64]. And yes, I should fix the other code instead of doing this trick.
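For the record, a simple alternative to view_as() (given the shapes above) would be squeezing the trailing dimension of the output, or unsqueezing the target:

output = output.squeeze(1)    # [64, 1] -> [64]
#or, equivalently:
target = target.unsqueeze(1)  # [64] -> [64, 1]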

Hello Abdeljalil!

Some comments, in line, below:

Mathematically, sigmoid() maps [-inf, inf] to [0.0, 1.0].
But sigmoid() has exp() in it, which “amplifies” numbers
dramatically, so numerically the range for which sigmoid()
gives useful results is rather small.

Specifically, sigmoid(x) saturates at 1.0 already for x = 17,
and saturates at 0.0 for x = -90 (for FloatTensors). These
are not extreme values and can easily be returned by a Linear
layer.

And once the sigmoid() saturates, its gradient becomes zero,
so the optimizer won’t know which direction to go to get you
“unstuck.”
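You can check the saturation and the vanishing gradient
directly (20 is comfortably past the float32 saturation point):

x = torch.tensor(20.0, requires_grad=True)
p = torch.sigmoid(x)
print(p)       # tensor(1., grad_fn=...) -- exactly 1.0 in float32
p.backward()
print(x.grad)  # tensor(0.) -- zero gradient, nothing for the optimizer to use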

These integer values for output surely occur only after you have
calculated output = torch.round(output).

You can do this and it will give you what you want. But, to me,
it’s a little more logically straightforward to compare (“threshold”)
your pre-sigmoid outputs with 0.

The rationale is:

An output of a linear layer (a raw-score logit) of 0.0 is mapped by
sigmoid() to a probability, P, of 0.5. To convert a probability to
a discrete yes-no prediction, one would typically say P > 0.5
means “yes” (and P < 0.5 means “no”).

So instead of using sigmoid() to convert to probabilities, you
can just threshold the logits directly: logit > 0.0 means “yes”
(and logit < 0.0 means “no”).
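As a quick sanity check (reusing a few of the logits you
printed), the two decision rules agree:

logits = torch.tensor([7., -3., -5., 4.])
via_sigmoid = torch.round(torch.sigmoid(logits))  # tensor([1., 0., 0., 1.])
via_threshold = (logits > 0.0).float()            # tensor([1., 0., 0., 1.])
print(torch.equal(via_sigmoid, via_threshold))    # True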

Not that it really matters now that you have your model working,
but are you possibly using double-precision in your keras code?

Do you have an explicit sigmoid() in your keras code? If so,
for what values does the keras sigmoid() saturate?

Best.

K. Frank