BCEWithLogitsLoss validation loss did not decrease after applying pos_weight

Hi everyone, I’m dealing with a multilabel segmentation problem where I have 2 classes - a Background class and an Object class. The objects are very small, which means there is a class imbalance in the data. Thus, I decided to use BCEWithLogitsLoss with pos_weight.

  1. For the first epoch the training loss was very high - around 7. Is it normal to start with such a high loss? It only happens when I use the pos_weight parameter.
  2. From the second epoch the loss became normal, but after that the validation loss did not decrease. In general, what is the best practice for choosing a loss function in such tasks?

The model outputs (and also the target images) have this shape: [Batch, Number of Classes, H, W]

    def _calculate_positive_weight(self, label_images) -> torch.Tensor:
        # per-class counts of positive pixels, summed over batch and spatial dims
        class_counts = label_images.sum(dim=(0, 2, 3))
        # total number of pixels in a single class channel
        total_count = label_images.numel() / label_images.size(1)
        # weight for the "Object" channel (index 1)
        positive_weight = total_count / (2 * class_counts[1].float())
        positive_weight_tensor = positive_weight.reshape(1).to(self.device)
        return positive_weight_tensor

    def _initialize_criterion(self):
        # pos_weight is estimated from the first training batch only
        for batch in self.training_data_loader:
            label_images = batch["label_image_tensor"].float().to(self.device)
            positive_weight_tensor = self._calculate_positive_weight(label_images)
            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weight_tensor)
            break

I appreciate your support.

Hi Mahammad!

The short story – making some assumptions about your use case and
what your model is doing – is that you should modify your model to predict
a single value per pixel and modify your target images accordingly.

A note on terminology: I think you mean you have a multiclass problem.
That is, you have a single-label, two-class problem.

By way of example, a multilabel problem would be one where
you classify images according to three attributes, say animal - cat, dog,
rabbit (three classes); type - black and white, color (two classes); focus -
blurry, sharp (two classes).

Please correct me if you really do have a multilabel problem – this matters
for how the rest of your post should be understood.

As I understand you, your model classifies each pixel to be either
“Background” or “Object.” So you have a two-class – which is to say
binary – segmentation problem.

While conceptually the same, the pytorch infrastructure distinguishes
between a binary problem and a general multiclass problem that happens
to be a two-class (and hence conceptually binary) problem.

Specifically, if you structure your problem as binary, your model should
predict a single output per pixel and that output should be (if you are
using BCEWithLogitsLoss) the logit that corresponds to the probability
of that pixel being in the “Object” class.

On the other hand, if you structure your problem as a multiclass problem,
that in your case happens to be two-class, your model should predict two
outputs per pixel that correspond to the unnormalized log-probabilities of
that pixel being in each of the two classes, “Background” and “Object.”
For such a multiclass problem – even in the case of two classes – you
would typically use CrossEntropyLoss as your loss criterion.
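
To make the distinction concrete, here is a minimal sketch of the two setups
(the shapes are the point; the random tensors just stand in for real model
outputs and targets):

    import torch

    batch, num_classes, h, w = 4, 2, 224, 224

    # binary setup: one logit per pixel, floating-point target of 0.0 / 1.0
    binary_logits = torch.randn(batch, h, w)                            # no class dimension
    binary_target = torch.randint(0, 2, (batch, h, w)).float()
    loss_bce = torch.nn.BCEWithLogitsLoss()(binary_logits, binary_target)

    # two-class "multiclass" setup: two logits per pixel, integer class-index target
    multiclass_logits = torch.randn(batch, num_classes, h, w)
    multiclass_target = torch.randint(0, num_classes, (batch, h, w))    # long class indices
    loss_ce = torch.nn.CrossEntropyLoss()(multiclass_logits, multiclass_target)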

The [Batch, Number of Classes, H, W] shape of your outputs and targets says
to me that your model is treating this as a multiclass problem with two
classes. But, while a legitimate approach, this will not work
correctly with BCEWithLogitsLoss as your loss criterion.

So, to repeat, modify your model to predict a single output value per
pixel (with shape [Batch, H, W ] – no class dimension), modify your
target images analogously to have a single target value per pixel, 0.0
for “Background” and 1.0 for “Object” (again with shape [Batch, H, W ]),
use BCEWithLogitsLoss, and use its pos_weight constructor argument
to compensate for your class imbalance.
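
In code, that setup could look something like the following sketch, where
model, images, masks, neg_count, and pos_count are placeholders for your own
objects and for the pixel counts you would measure on your training data:

    import torch

    logits = model(images)          # desired shape: [Batch, H, W], one logit per pixel
    target = masks.float()          # 0.0 = "Background", 1.0 = "Object", same shape as logits

    # pos_weight = number of negative pixels / number of positive pixels
    pos_weight = torch.tensor([neg_count / pos_count])
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    loss = criterion(logits, target)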

(Note, BCEWithLogitsLoss will work correctly if both your predictions and
target images carry a singleton class dimension (that is, a “trivial” class
dimension of length one), but I prefer stylistically to leave such a class
dimension out.)

Best.

K. Frank


Thank you so much for such a detailed answer.

Problem Description
You are right, I confused the terminology - the problem is a binary segmentation problem. To give you more info - this is a change detection model that takes 2 images and produces the change mask.
I changed the model now.

  • Input: an RGB image converted to a tensor ([1, 3, 224, 224])
  • Target: a mask with 2 colors (white background, with a foreground showing the objects) - converted to a tensor ([1, 1, 224, 224]) - the values are binary (0 or 1)
  • Model output: a ([1, 1, 224, 224]) tensor - the values inside are not binary - to visualize the output I apply a sigmoid and a 0.5 threshold to make it binary
    with torch.no_grad():
        prediction = model((test_image, reference_image))

    # convert logits to probabilities, then threshold to get a binary mask
    output = torch.sigmoid(prediction)
    output = (output > 0.5).float()

Current Problems:

  1. The validation loss stops improving (i.e., decreasing) after a certain epoch (the 20th-25th), so early stopping happens.
  2. Also, the validation loss is lower than the training loss for the first 5-6 epochs, which I believe should be the opposite.

Does this mean overfitting, underfitting, or both? Can it be because the model is too complex for the task, so it learns quickly but cannot generalize? I used several techniques such as weight decay and dropout, but none of them really worked.

Questions:

  1. Should I change the model architecture so that the input and output will be [batch, H, W], or can I keep [batch, 1, H, W]? If I keep [batch, 1, H, W], I can’t use CrossEntropy (the loss values become zero if I do).

  2. If I use the pos_weight parameter, how should I calculate it? As I understand it, it is the ratio of negative to positive pixels - in my case this ratio is very high, because the images usually contain only a tiny blue part, so I think this is a huge class imbalance. For the training set the average ratio is 311, for the validation set 1179, and for the test set 414. Should I instead use an intuitive general factor for this ratio (for example pos_weight = 4, meaning an 80/20 ratio)? As I understood, this is a trade-off between precision and recall: if pos_weight is high, then recall starts from a very high value (0.96, for example) and then decreases, while all the other metrics slowly increase. If I use the real ratio (which is very high), the validation loss stops improving after the 6th-7th epoch and early stopping happens.

The questions may be really basic or noob questions - I’m sorry in advance. I’m just a beginner and I’d appreciate your help.

Hi Mahammad!

Your input / target / model-output setup is fine.

For inference (e.g., validation), this code is fine. For training (where you would
use backpropagation), you would not use with torch.no_grad():.
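
For completeness, a training step could look something like this sketch (the
optimizer, the criterion, and the batch keys are assumptions based on your
earlier snippets):

    model.train()
    for batch in training_data_loader:
        optimizer.zero_grad()
        # no torch.no_grad() here -- gradients are needed for backpropagation
        prediction = model((batch["test_image"], batch["reference_image"]))
        loss = criterion(prediction, batch["label_image_tensor"].float())
        loss.backward()
        optimizer.step()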

You haven’t said how big your dataset is, so we don’t know how many
samples you have in an epoch. However, 25 epochs of training is in
general quite small. Try training much longer. It’s perfectly possible
for your training loss to “plateau” and then start making progress again
as you train more (and this can happen multiple times in a training run).

Your validation loss being lower than your training loss at first could just be “noise.” Six epochs is basically nothing. If the random
initialization of your model happens to better match the samples randomly
selected for your validation set, your validation loss might randomly be
lower than your training loss (and you might expect to see this about half
the time). Only if your validation loss systematically remains lower than
your training loss after you have trained much longer should you start to
suspect that something fishy might be going on.

Neither overfitting nor underfitting. For any but the simplest toy problems, you haven’t trained long
enough to reach any conclusions.

You can do it either way – the two are equivalent. For purely stylistic
reasons, I prefer to not have the singleton dimension, but you don’t
need to change your model architecture to get rid of it – you can simply
.squeeze() away any singleton dimensions that the output of your
model might have (if you care – again, it doesn’t matter).
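
For example (assuming your model emits a [Batch, 1, H, W] tensor, as in your
post):

    logits = model((test_image, reference_image))   # shape [Batch, 1, H, W]
    logits = logits.squeeze(1)                      # shape [Batch, H, W]; the target should match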

You don’t want to use CrossEntropyLoss. You should be using
BCEWithLogitsLoss.

I wouldn’t call a negative-to-positive ratio of several hundred a huge imbalance, but it is large enough that you will
want to compensate for it, for example, by using pos_weight. (Some
people suggest using an intersection-over-union (IoU) or Dice-coefficient
loss for imbalanced segmentation problems. My advice is to start with
pos_weight and only augment BCEWithLogitsLoss with something
like IoU if it’s clear that pos_weight isn’t working well enough.)
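
(Should you eventually want to experiment with that, a soft Dice term added to
BCEWithLogitsLoss could look something like the sketch below, with logits,
target, and pos_weight as in the earlier sketch; the equal weighting of the two
terms is purely illustrative.)

    import torch

    def soft_dice_loss(logits, target, eps=1.0e-6):
        # soft Dice loss on probabilities; target is a 0.0 / 1.0 mask of the same shape
        probs = torch.sigmoid(logits)
        intersection = (probs * target).sum()
        return 1.0 - (2.0 * intersection + eps) / (probs.sum() + target.sum() + eps)

    bce = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    loss = bce(logits, target) + soft_dice_loss(logits, target)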

No, a pos_weight value of 4 would be way too small to effectively compensate for your
class imbalance. A value of about 400 would be a good starting point.

It’s up to you and your use case whether precision or recall is more
important to you and therefore how you should tune that trade-off.

Start by using a value that is roughly equal to the real ratio. Train for much
longer so that your results are meaningful. Then look at your precision and
recall and adjust pos_weight, as appropriate, to achieve your desired
precision / recall trade-off.
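
A rough sketch of “roughly equal to the real ratio” in code, counting pixels
over your whole training set (the dataloader and the "label_image_tensor" key
are taken from your first post):

    import torch

    positives = 0.0
    negatives = 0.0
    for batch in training_data_loader:
        mask = batch["label_image_tensor"].float()      # 0.0 / 1.0 mask
        positives += mask.sum().item()
        negatives += mask.numel() - mask.sum().item()

    pos_weight = torch.tensor([negatives / positives])  # negative-to-positive pixel ratio
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)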

Again, 7 epochs of training is almost nothing.

Turn off early stopping (so that you can let your training run much longer).

Best.

K. Frank
