Model always predicts tensors of 0 values

Hi,

I’m trying to solve a classification problem with 2 classes for a biomedical application using images.
After reading a lot on this forum, I’ve seen that one option is to use BCEWithLogitsLoss with a single output neuron. I have class imbalance, so I am using the pos_weight parameter, although I’m not sure I’m using it correctly.
I have some questions…

1) Is the Adam optimizer a good idea here, or should I use SGD?
import torch
import torch.nn as nn
import torch.optim as optim

optimizer_ft = optim.Adam(params_to_update, lr=hparams['learning_rate'])

labels = torch.as_tensor(labels)                # all training labels (0/1)
num_positives = (labels == 1).sum().float()     # 250
num_negatives = (labels == 0).sum().float()     # 604
pos_weight = num_negatives / num_positives      # around 2.4, float32 like the logits

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(get_device())
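To check how pos_weight enters the loss, here is a minimal sketch with made-up toy logits and targets (not the data from this post). It compares nn.BCEWithLogitsLoss against the formula in the PyTorch documentation, where pos_weight multiplies only the term for positive (target == 1) samples:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy logits and float targets shaped [nBatch, 1], as with a single output neuron
logits = torch.tensor([[1.0], [-0.5], [2.0], [-1.5]])
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
pos_weight = torch.tensor([2.4])  # roughly num_negatives / num_positives

builtin = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, targets)

# Manual version: pos_weight scales only the positive term
log_p = F.logsigmoid(logits)        # log(sigmoid(x))
log_not_p = F.logsigmoid(-logits)   # log(1 - sigmoid(x))
manual = -(pos_weight * targets * log_p + (1 - targets) * log_not_p).mean()

print(builtin.item(), manual.item())
```

In effect, a positive sample counts about 2.4 times as much as a negative one, which roughly balances 250 positives against 604 negatives.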

2) Is my training function correct? Here it is:

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels_raw, _, _ in dataloaders[phase]:
                inputs = inputs.float()
                labels = labels_raw.float()

                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels.unsqueeze(1))
                    _, preds = torch.max(outputs, dim=1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == torch.argmax(labels)).item()

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)

            if phase == 'train':
                train_loss.append(epoch_loss)
                train_acc.append(epoch_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)

            print('{} loss: {:.4f}, Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

These are always my predictions:

tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

This is a sample target for one batch of 8:

tensor([0., 0., 0., 1., 0., 0., 0., 1.], dtype=torch.float64)

There is definitely something wrong with my code, but I can’t find it. Any help would be appreciated.

I am also normalizing my images: I compute the per-channel mean and std over the whole dataset and then normalize each data split.
3) Should I normalize the validation/test splits with the same parameters?
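For reference, here is a minimal sketch of the per-channel computation described above. It assumes the images are stacked into a float tensor of shape [N, C, H, W]. The helpers channel_stats and normalize are hypothetical names, and the random tensor stands in for real data:

```python
import torch

def channel_stats(images: torch.Tensor):
    """Per-channel mean and std over a stack of images shaped [N, C, H, W]."""
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std

def normalize(images: torch.Tensor, mean: torch.Tensor, std: torch.Tensor):
    # Reshape to [1, C, 1, 1] so the statistics broadcast over N, H and W
    return (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

torch.manual_seed(0)
images = torch.rand(16, 3, 32, 32)  # hypothetical stand-in for the real images
mean, std = channel_stats(images)
normed = normalize(images, mean, std)
print(normed.mean(dim=(0, 2, 3)), normed.std(dim=(0, 2, 3)))  # close to 0 and 1
```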

Thanks

Hi Nil!

With one single output neuron (if I understand correctly what you mean), the
last dimension of your outputs will have size 1, e.g., outputs would have a
shape of something like [nBatch, 1].

Your problem (assuming that your “predictions” is your preds) is that
_, preds = torch.max(outputs, dim=1) will always return 0 because
dim = 1 of outputs has size 1. Consider:

>>> import torch
>>> torch.__version__
'1.13.0'
>>> _ = torch.manual_seed (2022)
>>> model = torch.nn.Linear (10, 1)
>>> inputs = torch.randn (5, 10)   # batch size of 5
>>> outputs = model (inputs)
>>> outputs   # okay
tensor([[ 1.0431],
        [-0.8160],
        [-1.3349],
        [-0.3153],
        [ 0.9312]], grad_fn=<AddmmBackward0>)
>>> _, preds = torch.max (outputs, dim = 1)   # computes argmax along dim = 1
>>> preds   # but argmax is always 0
tensor([0, 0, 0, 0, 0])

For the binary case where you have only a single output “neuron” (that has
not been passed through a sigmoid()), you would want to compute your
integral yes-no predictions by thresholding outputs against zero:

preds = outputs > 0.0            # returns a bool
# or, if you prefer
preds = (outputs > 0.0).long()   # explicit cast to long

(You have the analogous error in your expression for running_corrects,
as well.)
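Here is a minimal sketch of how the preds and running_corrects lines could be computed instead. It assumes outputs has shape [nBatch, 1] and labels has shape [nBatch], as in the loop above. count_correct is a hypothetical helper; the logits are the ones printed in the example, paired with made-up labels:

```python
import torch

def count_correct(outputs: torch.Tensor, labels: torch.Tensor) -> int:
    """outputs: raw logits shaped [nBatch, 1]; labels: 0/1 targets shaped [nBatch]."""
    # Threshold the logits against zero (no sigmoid applied) and drop the size-1 dim
    preds = (outputs > 0.0).long().squeeze(1)
    # Compare element-wise with the labels instead of using torch.argmax(labels)
    return (preds == labels.long()).sum().item()

outputs = torch.tensor([[1.0431], [-0.8160], [-1.3349], [-0.3153], [0.9312]])
labels = torch.tensor([1., 0., 0., 1., 1.])
print(count_correct(outputs, labels))  # 4
```

In the training loop, this would replace both the torch.max line and the torch.argmax(labels) comparison, e.g. running_corrects += count_correct(outputs, labels).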

Best.

K. Frank

Thanks for the reply,

you are right!

But, as explained in the documentation, BCEWithLogitsLoss:

This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.
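To illustrate the stability point in that quote, here is a minimal sketch using one made-up, extreme logit. In float32, sigmoid(200.0) rounds to exactly 1.0, so a plain sigmoid followed by BCELoss ends up taking log(0). BCELoss clamps its log outputs at -100, while the fused loss works on the logit directly:

```python
import torch
import torch.nn as nn

logit = torch.tensor([200.0])   # a very confident (and wrong) positive logit
target = torch.tensor([0.0])

# Plain sigmoid + BCELoss: log(1 - 1.0) = log(0) is clamped to -100
naive = nn.BCELoss()(torch.sigmoid(logit), target)

# Fused loss computed from the logit keeps the exact value
stable = nn.BCEWithLogitsLoss()(logit, target)

print(naive.item(), stable.item())  # 100.0 200.0
```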

So I guess my threshold should be

outputs > 0.5

I also don’t understand how to compute the pos_weight parameter of the loss.
Does it need to be computed on the whole dataset and passed as a single number?
pos_weight = negative_samples / positive_samples

Or does it have to be computed for every batch?

pos_weight = (labels == 0).sum() / labels.sum()

Hi Nil!

No. You are right that BCEWithLogitsLoss computes sigmoid() internally
(actually log_sigmoid()), but this doesn’t affect your outputs. (It affects
how outputs is used in the loss computation, but doesn’t modify outputs.)
They will still be logits that range from -inf to inf, so you should still
threshold against 0.0.
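A quick check on random logits shows that thresholding the logits at 0.0 gives the same predictions as thresholding sigmoid(outputs) at 0.5. This holds because sigmoid(0) = 0.5 and sigmoid is increasing. Thresholding raw logits at 0.5 would instead correspond to a probability threshold of about 0.62:

```python
import torch

torch.manual_seed(0)
outputs = torch.randn(8, 1) * 3   # raw logits, as returned by the model

preds_logit = outputs > 0.0                  # threshold the logits at 0.0
preds_prob = torch.sigmoid(outputs) > 0.5    # threshold probabilities at 0.5

print(torch.equal(preds_logit, preds_prob))  # True
```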

First, your imbalance of about 2.4 is not really very large, so you probably
don’t need pos_weight in practice. (It’s fine to use it though, and could
offer some minor benefits.)

Second, the exact value of pos_weight shouldn’t matter. If your training
is sensitive to small changes in pos_weight, you probably have some sort
of instability in your training that you would also see if you varied some
other hyperparameter or used different random initialization of your network.

So it shouldn’t really matter exactly how you compute pos_weight (as long
as you don’t end up with inf). I would generally recommend computing
pos_weight for your entire training set, or if that is inconvenient, for a
representative subset of your training set.

(Note, your second expression, (labels == 0).sum() / labels.sum(),
should be something like (labels == 0).sum() / (labels == 1).sum().)
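Here is a minimal sketch of computing pos_weight once over the whole training set, using the corrected (labels == 0).sum() / (labels == 1).sum() form. The labels, the DataLoader, and the pos_weight_from_labels helper are hypothetical stand-ins; the loader in the question yields four items per batch, not two. The guard against zero positives matters most for a per-batch variant, because a small batch can easily contain no positives:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def pos_weight_from_labels(labels: torch.Tensor) -> torch.Tensor:
    """num_negatives / num_positives for 0/1 labels; raises instead of returning inf."""
    num_pos = (labels == 1).sum()
    num_neg = (labels == 0).sum()
    if num_pos == 0:
        raise ValueError("no positive samples, pos_weight would be inf")
    return (num_neg.float() / num_pos.float()).reshape(1)

# Hypothetical training labels with the 604 / 250 split from the question
train_labels = torch.cat([torch.zeros(604), torch.ones(250)])
loader = DataLoader(TensorDataset(torch.randn(854, 4), train_labels), batch_size=8)

# Whole training set: collect every label once and compute a single number
all_labels = torch.cat([y for _, y in loader])
pos_weight = pos_weight_from_labels(all_labels)
print(pos_weight)  # tensor([2.4160])
```

The resulting tensor can then be passed once, as nn.BCEWithLogitsLoss(pos_weight=pos_weight).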

(One could plausibly argue that there is a mild theoretical justification for
computing pos_weight on a per-batch basis, but it really shouldn’t matter.)

Best.

K. Frank