Binary classification on MNIST: loss and accuracy remain constant

I am trying to do binary classification on the MNIST dataset: class 0 for even digits and class 1 for odd digits. I am using a simplified version of VGG.
My network's loss and accuracy remain constant.
I should mention that the same model reached over 90% accuracy before I changed the targets to binary, so something is probably wrong with my changes.
Here is how I convert the targets to binary:

for i in range(10):
    idx = (train_set.targets == i)
    if (i == 0) or ((i % 2) == 0):
        train_set.targets[idx] = 0
    else:
        train_set.targets[idx] = 1

for i in range(10):
    idx = (test_set.targets == i)
    if (i == 0) or ((i % 2) == 0):
        test_set.targets[idx] = 0
    else:
        test_set.targets[idx] = 1

This is my net:

class VGG16(nn.Module):

    def __init__(self, num_classes):
        super(VGG16, self).__init__()

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      # (1(32-1)- 32 + 3)/2 = 1
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )
        
        self.block_3 = nn.Sequential(
            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(in_channels=256,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )            

        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, num_classes),
            nn.Sigmoid() 
        )

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
#                 nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    m.bias.detach().zero_()

        # self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, x):

        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        # x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
        #logits = self.classifier(x)
        #probas = F.softmax(logits, dim=1)
        # probas = nn.Softmax(logits)
        #return probas
        # return logits
# Define an optimizer
import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr = 0.01)
# Define a loss 
criterion = nn.BCELoss()


def train(net, loaders, optimizer, criterion, epochs=20, dev=dev, save_param = False, model_name="valerio"):
    try:
        net = net.to(dev)
        #print(net)
        # Initialize history
        history_loss = {"train": [], "val": [], "test": []}
        history_accuracy = {"train": [], "val": [], "test": []}
        # Store the best val accuracy
        best_val_accuracy = 0

        # Process each epoch
        for epoch in range(epochs):
            # Initialize epoch variables
            sum_loss = {"train": 0, "val": 0, "test": 0}
            sum_accuracy = {"train": 0, "val": 0, "test": 0}
            # Process each split
            for split in ["train", "val", "test"]:
                if split == "train":
                  net.train()
                else:
                  net.eval()
                # Process each batch
                for (input, labels) in loaders[split]:
                    # Move to CUDA
                    input = input.to(dev)
                    labels = labels.to(dev)
                    # Reset gradients
                    optimizer.zero_grad()
                    # Compute output
                    pred = net(input)
                    labels = labels.unsqueeze(1)
                    labels = labels.float()
                    loss = criterion(pred, labels)
                    # Update loss
                    sum_loss[split] += loss.item()
                    # Check parameter update
                    if split == "train":
                        # Compute gradients
                        loss.backward()
                        # Optimize
                        optimizer.step()
                    # Compute accuracy
                    _,pred_labels = pred.max(1)
                    batch_accuracy = (pred_labels == labels).sum().item()/input.size(0)
                    # Update accuracy
                    sum_accuracy[split] += batch_accuracy
            # Compute epoch loss/accuracy
            epoch_loss = {split: sum_loss[split]/len(loaders[split]) for split in ["train", "val", "test"]}
            epoch_accuracy = {split: sum_accuracy[split]/len(loaders[split]) for split in ["train", "val", "test"]}

            # Store params at the best validation accuracy
            if save_param and epoch_accuracy["val"] > best_val_accuracy:
              #torch.save(net.state_dict(), f"{net.__class__.__name__}_best_val.pth")
              torch.save(net.state_dict(), f"{model_name}_best_val.pth")
              best_val_accuracy = epoch_accuracy["val"]

            # Update history
            for split in ["train", "val", "test"]:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])
            # Print info
            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},")
    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Plot loss
        plt.title("Loss")
        for split in ["train", "val", "test"]:
            plt.plot(history_loss[split], label=split)
        plt.legend()
        plt.show()
        # Plot accuracy
        plt.title("Accuracy")
        for split in ["train", "val", "test"]:
            plt.plot(history_accuracy[split], label=split)
        plt.legend()
        plt.show()

From the previous digit-recognition model I changed only the targets and the final layer of the classifier, from 10 classes to 1 class + Sigmoid. I also changed the loss from cross entropy to BCELoss. What am I doing wrong?

These are loss and accuracy values:

Epoch 1: TrL=49.0955, TrA=31.4211, VL=49.7285, VA=31.7340, TeL=49.2635, TeA=31.3758,
Epoch 2: TrL=49.0992, TrA=31.4235, VL=49.7285, VA=31.7340, TeL=49.2635, TeA=31.3758,
Epoch 3: TrL=49.0899, TrA=31.4176, VL=49.7285, VA=31.7340, TeL=49.2635, TeA=31.3758,
Epoch 4: TrL=49.0936, TrA=31.4199, VL=49.7285, VA=31.7340, TeL=49.2635, TeA=31.3758,
Epoch 5: TrL=49.0936, TrA=31.4199, VL=49.7285, VA=31.7340, TeL=49.2635, TeA=31.3758,
Epoch 6: TrL=49.0825, TrA=31.4128, VL=49.7285, VA=31.7340, TeL=49.2635, TeA=31.3758,

What’s wrong? How is it possible that with 10 classes I reached over 90% accuracy, while with this simplified version, with only 2 classes, I only reach about 30% accuracy?

Edit: after increasing the batch size from 64 to 128, accuracy reaches 60% and remains constant…

Hi Bruno!

I do see what appears to be an error in your conversion from a
multi-class to a binary model (although I can’t explain the entirety
of your results).

Specifically, I believe your use of argmax() to convert your probabilistic
predictions into “hard” yes-no binary predictions is incorrect, and will
always yield 0.

When you create your net, presumably by instantiating your VGG16,
you don’t show us what value you use for num_classes. I assume
that you use num_classes = 1. (Otherwise I would expect that
BCELoss would throw an error when you call it.)

Therefore I assume that the output of your model has shape
[nBatch, 1]. Note, even though you are performing a binary
classification, I assume that you do not use num_classes = 2
and that your output does not have shape [nBatch, 2].

(In the case of a ten-class classifier, your model output would have
shape [nBatch, 10].)

Here I believe that labels starts out having shape [nBatch], with no
class dimension, as would be appropriate for a multi-class problem
using CrossEntropyLoss. However, BCELoss requires that pred
and labels have the same shape, and I believe that pred has
shape [nBatch, 1] (where the singleton dimension is not necessary,
but is also not incorrect).

That is, pred has shape [nBatch, num_classes], which is appropriate
for both a multi-class problem with CrossEntropyLoss and for a plain
binary problem (or a multi-label, multi-class problem) with BCELoss,
with, in this case, num_classes = 1.

Therefore you unsqueeze() labels in order that labels have the
same shape as pred, namely, [nBatch, 1].
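
For concreteness, here is a minimal sketch of that shape bookkeeping, using made-up tensors rather than your actual model and loader:

import torch
import torch.nn as nn

pred = torch.sigmoid(torch.randn(8, 1))   # model output after Sigmoid, shape [nBatch, 1]
labels = torch.randint(0, 2, (8,))        # binary targets straight from the loader, shape [nBatch]

# BCELoss requires pred and labels to have the same shape (and float targets),
# hence the unsqueeze() and float() in your training loop:
labels = labels.unsqueeze(1).float()      # shape [nBatch, 1]
loss = nn.BCELoss()(pred, labels)
print(pred.shape, labels.shape, loss.item())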

So far, so good.

Just for clarity, note that _,pred_labels = pred.max(1) is the same
as pred_labels = pred.argmax(1).

So your predicted class label is the index of the largest predicted
value along the class dimension. pred_labels takes on integer
values in [0, num_classes - 1] (inclusive). This is correct for
the multi-class case, but in the binary case (implemented with
num_classes = 1), pred_labels will always have value 0.

So the predictions you use for computing your accuracy are always
predicting “0”. (Note that this issue does not affect the predictions
you use for computing your loss. As far as I can tell, your loss should
be correct and your training should work.)
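
You can see this with a tiny made-up example:

import torch

pred = torch.rand(4, 1)   # shape [nBatch, 1], i.e. num_classes = 1
print(pred.argmax(1))     # tensor([0, 0, 0, 0]) -- the only index along dim 1 is 0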

What’s odd here is that your computed accuracy of about 31% suggests
that your dataset is rather unbalanced. If my theory here is right, this
means that your dataset consists of about 31% even-number digits,
and about 69% odd-number digits. This isn’t impossible, but it does
seem a bit peculiar.
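
(You could check the actual balance directly; a quick sketch, assuming train_set.targets is the torchvision MNIST targets tensor after your 0 / 1 remapping:)

import torch

counts = torch.bincount(train_set.targets)   # counts of class 0 and class 1
print(counts.float() / counts.sum())         # standard MNIST should come out roughly 0.49 / 0.51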

That your computed accuracy depends on your batch size is very
strange. I don’t see any batch-size-dependent error in your accuracy
calculation. You don’t show us your loaders, so one could speculate
that something weird is going on there.

A couple of minor comments:

First, when you get this otherwise working, you should consider
using BCEWithLogitsLoss and getting rid of the final Sigmoid
layer in your classifier block. BCELoss with the Sigmoid and
BCEWithLogitsLoss without are mathematically equivalent, but
BCEWithLogitsLoss is numerically more stable.
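
A quick illustration of that equivalence, with made-up logits and targets:

import torch
import torch.nn as nn

logits = torch.randn(8, 1)                       # raw scores from the final Linear layer
targets = torch.randint(0, 2, (8, 1)).float()

loss_a = nn.BCELoss()(torch.sigmoid(logits), targets)
loss_b = nn.BCEWithLogitsLoss()(logits, targets)
print(torch.allclose(loss_a, loss_b))            # True -- same value, computed more stably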

Second, your target-remapping loops (the two for i in range(10) loops at
the top of your post) are not incorrect, but they're unnecessarily
complicated.

They can all be simplified down to:

train_set.targets = train_set.targets % 2

(and likewise for test_set.targets).

Best.

K. Frank

Thank you for the answer, I think you are right.
As you guessed, when I create the model I use num_classes = 1. I changed the loss from BCELoss to BCEWithLogitsLoss and removed the Sigmoid, following your advice.
My loaders are these:

# Define dictionary of loaders
loaders = {"train": train_loader,
           "val": val_loader,
           "test": test_loader}

I don’t know why changing the batch size changed the results…
However, what do you suggest?
Do you think it is correct to write this?

pred_labels = pred.argmax(1)+1

Edit: no, the solution I proposed is wrong; this way I will always obtain 1. I don’t know how I can fix it…
Can I just remove argmax/max and simply use the output?

When you said:

Here I believe that labels starts out having shape [nBatch], with no
class dimension, as would be appropriate for a multi-class problem
using CrossEntropyLoss. However, BCELoss requires that pred
and labels have the same shape, and I believe that pred has
shape [nBatch, 1] (where the singleton dimension is not necessary,
but is also not incorrect).

I tried to change this, using:

pred = pred.squeeze(dim=1) # Output shape is [Batch size, 1], but we want [Batch size]

But it returned an error saying that the shapes must be the same, so I removed it to avoid other problems :smiley:

Now, I did this:

# Compute accuracy
#pred_labels = pred.argmax(1) + 1
pred_labels = (pred >= 0.5).long() # Binarize predictions to 0 and 1
batch_accuracy = (pred_labels == labels).sum().item()/input.size(0)

And now accuracy is over 97% :smiling_face_with_three_hearts:
Do you think it is a good solution?

Hi Bruno!

This potentially makes sense. Let me offer a few words of explanation:

You have a binary classifier that has a single output (per batch
sample). If this value is the output of your final Linear layer,
without a Sigmoid, it should be understood as a raw-score
logit that runs from -inf to inf and you should feed it into
BCEWithLogitsLoss (preferred). If you have a Sigmoid then
you are converting the logit to a probability in [0.0, 1.0] and
you should feed it into BCELoss (not preferred).

This probability (after the Sigmoid) is your predicted probability of
your sample being in the “1” or “yes” class. The probability of your
sample being in the “0” or “no” class is one minus the “1”-class
probability.

How do you convert such a probabilistic prediction into a hard yes-no
prediction? A reasonable and straightforward approach is to say that
if your predicted probability (for your “1” class) is greater than 1/2,
then you will make a hard prediction of your sample being in the
“1” class. Conversely, if your predicted probability is less than 1/2,
you make the “0”-class prediction.

So your (pred >= 0.5).long() is thresholding your probabilistic
prediction, pred, against 1/2 to get a hard prediction. This is correct
if pred is a probability (that is, if you have a Sigmoid after your final
Linear layer), but is imperfect if pred is a logit (that is, no Sigmoid).

What should you do if pred is a logit? Note that Sigmoid maps logits
that are negative to probabilities that are less than 1/2 (and positive
logits to probabilities that are greater than 1/2). So thresholding the
logit against 0 is equivalent to thresholding the corresponding
probability against 1/2.
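
A quick check of that equivalence with some made-up logits:

import torch

logits = torch.randn(10)
prob_preds = (torch.sigmoid(logits) >= 0.5)
logit_preds = (logits >= 0.0)
print(torch.equal(prob_preds, logit_preds))   # True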

Therefore if your pred is a logit you should tweak your code to read
(pred >= 0.0).long().

As noted before, BCEWithLogitsLoss is numerically more stable
than BCELoss, so you are better off using BCEWithLogitsLoss
without the Sigmoid, having your predictions be logits, and therefore
thresholding your predictions against 0.
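
Putting the pieces together, here is a sketch of that recommended setup for a single batch, with made-up tensors standing in for net(input) and the loader:

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()          # expects raw logits, so no Sigmoid in the model

pred = torch.randn(64, 1)                   # stand-in for net(input): raw logits, shape [nBatch, 1]
labels = torch.randint(0, 2, (64,))         # stand-in for the loader's binary targets, shape [nBatch]

labels = labels.unsqueeze(1).float()        # [nBatch, 1] to match pred
loss = criterion(pred, labels)

# hard predictions: threshold the logits against 0 (equivalent to probability >= 0.5)
pred_labels = (pred >= 0.0).long()
batch_accuracy = (pred_labels == labels).sum().item() / pred.size(0)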

Best.

K. Frank

Thank you very much for all these exhaustive answers!!