Binary classification model average loss not changing in training

I’m training a binary classification model that takes a list of numerical values as input and classifies each one with a binary label. The dataset is unbalanced: about 95% 0s and about 5% 1s. During training the average loss doesn’t change at all. I have played around with class weights, and I can get the average loss to invert, but it never actually changes over the course of training.

Model:

import torch
from torch import nn

class Network(nn.Module):
    def __init__(self, inputSize):
        super(Network, self).__init__()
        self.inputSize = inputSize
        self.fc1 = nn.Linear(inputSize, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # sigmoid output to pair with BCELoss
        return x



model = Network(48).double().to(device)
loss_fn = nn.BCELoss().to(device) 
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, sample in enumerate(dataloader):
        X = torch.as_tensor(sample['Text'], dtype=torch.double).to(device)
        y = sample['Class'].double().to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 500 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for sample in dataloader:
            X = torch.as_tensor(sample['Text'], dtype=torch.double).to(device)
            y = sample['Class'].double().to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # count predictions exactly equal to the labels
            correct += (pred == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


I get that some of the size calculations are a bit off, but I don’t think that changes the reported loss value. I’ve trained for up to 20 epochs (the dataset has millions of entries) and there has been no change in the average loss to five decimal places. The accuracy sits at the percentage of 0s in the dataset.

Hi Malachi!

I would suggest first making sure that you can overfit a subset of your
data. Take a balanced subset (equal number of 0s and 1s) of your
dataset – perhaps something like 100 or 1000 samples – and run lots
of epochs on it. You should be able to train it to get essentially 100%
accuracy because the model will be learning the specific samples that
it keeps seeing again and again (rather than learning the general
features that it would need to predict effectively on your test set).

If you get that working, try overfitting an unbalanced subset (both with
and without weights).

If you can’t get overfitting to work, you almost certainly have a bug
somewhere.
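
For concreteness, here is one way you might carve out such a balanced
subset (a sketch only; labels and full_dataset are hypothetical names
for wherever your 'Class' column and Dataset live):

import torch
from torch.utils.data import Subset, DataLoader

n_per_class = 50  # 100 samples total
pos_idx = torch.nonzero(labels == 1).flatten()[:n_per_class]
neg_idx = torch.nonzero(labels == 0).flatten()[:n_per_class]
subset = Subset(full_dataset, torch.cat([pos_idx, neg_idx]).tolist())
overfit_loader = DataLoader(subset, batch_size=10, shuffle=True)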

I would recommend using BCEWithLogitsLoss instead of BCELoss.
(If you do, you need to get rid of the final sigmoid() activation in
your model as BCEWithLogitsLoss has, in effect, sigmoid() built in.)

BCEWithLogitsLoss has better numerical stability than BCELoss and
it also supports the pos_weight constructor argument that will be the
most convenient way to reweight your unbalanced dataset.
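
In your case that might look something like the following (a sketch; the
19.0 is just the approximate negative-to-positive ratio implied by your
95/5 split, so compute the real ratio from your data):

loss_fn = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor([19.0], dtype=torch.double, device=device)
)

# and Network.forward() should return raw logits:
def forward(self, x):
    x = torch.relu(self.fc1(x))
    x = torch.relu(self.fc2(x))
    return self.fc3(x)  # no sigmoid() -- BCEWithLogitsLoss applies it internally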

(pred == y) performs an exact equality test on floating-point numbers
which (for useful values of pred) will fail, because two floating-point
numbers that “ought” to be equal will typically differ by floating-point
round-off error. You should threshold pred to turn it into a binary 0.0 vs
1.0 prediction. (If you use BCEWithLogitsLoss – as you should – you
should threshold against 0.0. If you insist upon using BCELoss, you
should threshold against 0.5.)
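
In your test() loop that might look like this (a sketch; the view(-1)
calls guard against a (batch, 1) versus (batch,) shape mismatch silently
broadcasting):

# threshold raw logits at 0.0 (BCEWithLogitsLoss, no sigmoid() in the model)
pred_labels = (pred.view(-1) > 0.0).double()
correct += (pred_labels == y.view(-1)).sum().item()
# with BCELoss / sigmoid() outputs you would threshold at 0.5 instead:
# pred_labels = (pred.view(-1) > 0.5).double()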

This suggests that all (or almost all) of your predictions are (exactly) 0.0.
This would be a little unexpected, although possible: your sigmoid() could
be underflowing to exactly 0.0 (which is part of the reason to use
BCEWithLogitsLoss without the sigmoid()).

But, as noted above, first start with the overfitting test and see if you can
get that to work.

Best.

K. Frank

Thanks for the help.
I took a balanced sample of my dataset, switched to BCEWithLogitsLoss, and used this function:

def binary_acc(y_pred, y_test):
    # y_pred are raw logits, so apply sigmoid before rounding to 0/1
    y_pred_tag = torch.round(torch.sigmoid(y_pred))

    correct_results_sum = (y_pred_tag == y_test).sum().float()
    acc = correct_results_sum / y_test.shape[0]
    acc = torch.round(acc * 100)  # percentage, rounded to a whole number

    return acc

to calculate the accuracy.
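I call it on the raw outputs for each batch and average over the epoch, roughly like this (epoch_acc is just an illustrative accumulator):

acc = binary_acc(pred.view(-1), y.view(-1))
epoch_acc += acc.item()
# after the loop: epoch_acc / num_batches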
The average loss value is now changing. It fluctuates up and down without following any trend, and the accuracy sits at exactly 50.0% on the balanced dataset, which would seem to mean that the model is still picking just one answer.

Hi Malachi!

Just to confirm: Did you remove the sigmoid() from Network when
you switched to BCEWithLogitsLoss?

Did you try to overfit a small subset of your data? If so, how many
individual data samples did you use, how many epochs did you run,
and what batch size did you use?

You might try increasing your learning rate (lr) to see if you can get
your model to train faster. (You might also try decreasing your learning
rate in case your training has become unstable, but that seems unlikely.)

Have you looked at the predictions? Are they all just one answer?
Even if your “rounded” binary predictions are all the same, do you
get any significant variation in the unrounded predictions that come
directly from Network?
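
For example, you could log summary statistics of the raw outputs for a
single batch, something like this (X here stands for one batch of inputs,
prepared the same way as in your train()):

model.eval()
with torch.no_grad():
    logits = model(X).view(-1)  # raw, unrounded outputs
print(f"min {logits.min().item():.4f}  max {logits.max().item():.4f}  "
      f"mean {logits.mean().item():.4f}  std {logits.std().item():.4f}")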

Best.

K. Frank

Yes, I did remove it.

I did an overfit with about 500,000 samples, fed in by batches of 1 over 100 epochs. I recorded the predictions, and by the end the model was just returning values between -2 and -4 before thresholding. There didn’t seem to be any consistency among the values, and they all round to 0.
I tried increasing the learning rate and it didn’t change anything.