What is the best loss function for a binary segmentation problem with class imbalance?

I tried FocalLoss, TverskyLoss, Focal_Tversky_Loss, DiceLoss, and BCEWithLogitsLoss; none of them has worked well so far. P.S. I also have a data problem: the number of images is small (around 4k for training, 1k for validation), and maybe that is the source of the issue.

What are some best practices for binary segmentation with class imbalance (foreground pixels are few relative to background pixels)?

Hi Mahammad!

In what sense did they not “work well?”

This post suggests using BCEWithLogitsLoss together with its
pos_weight constructor argument, possibly augmented with something
like (a differentiable version of) the Dice coefficient (if you can show that
it actually improves training and performance).
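Something along these lines, as a minimal sketch (the pos_weight value
and the dice_weight hyperparameter below are placeholders, not
recommendations):

```python
import torch
import torch.nn as nn

# Weighted BCE: pos_weight > 1 up-weights the rare foreground pixels.
# 10.0 is a placeholder; a common starting point is the ratio
# (number of background pixels) / (number of foreground pixels).
bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]))

def soft_dice_loss(logits, targets, eps=1.0):
    # A differentiable version of the Dice coefficient: use predicted
    # probabilities instead of hard 0 / 1 predictions.
    probs = torch.sigmoid(logits)
    intersection = (probs * targets).sum()
    return 1.0 - (2.0 * intersection + eps) / (probs.sum() + targets.sum() + eps)

def combined_loss(logits, targets, dice_weight=1.0):
    # dice_weight trades off the two terms; tune it on your validation set.
    return bce(logits, targets) + dice_weight * soft_dice_loss(logits, targets)
```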

Have you tried that yet?

Best.

K. Frank

Hello, thanks a lot for the answer.

What I meant was, they all overfit: after a certain epoch (usually the 6th or 7th) the validation loss increases while the training loss keeps decreasing. I removed early stopping to see if there would be any progress, but unfortunately the pattern continued indefinitely. I think my data is problematic (5k training images, 1k for validation and 1k for test) and that I don’t have enough of it, because the images I use for training are all from the same scene, the same place. Although objects are moving and changing location, I think it still affects the training. I tried BCEWithLogitsLoss with different pos_weight values (from 100 to 600), but they all overfit very early (by the 20th epoch at most).

I also tried different dropout rates to combat overfitting, but unfortunately that did not help either.

I will try combining the Dice coefficient with pos_weight as you advised and see what happens.

Thanks again for your time.

Hi Mahammad!

Yes, that does sound like overfitting (although earlier in your training than
I would have expected).

In general, that would sound like a reasonable amount of data (but not
a whole lot).

Yes, if your data samples are not independent of one another (another
way of saying it: if they are highly correlated), it counts as having less
data. So you may have, effectively, many fewer than 5k training images.
Depending on how “correlated” your training images are, this could be a
significant problem.

You don’t say so explicitly, but you make it sound like your validation and
test images might not be from the same scene as your training images.
Even if they’re from the same scene, they might be from different time
periods than the training images.

Either way, I would suggest (if you’re not doing this already) that you combine
your current training, validation, and test datasets into a single 7k dataset,
shuffle it, and then randomly divide it into new training, validation, and test
datasets. This will ensure that these datasets all have the same character
statistically.
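
A minimal sketch of that reshuffling, assuming your three existing
Dataset objects are named train_ds, val_ds, and test_ds and total
exactly 7,000 samples:

```python
import torch
from torch.utils.data import ConcatDataset, random_split

# Pool all 7k samples, then draw fresh random splits from the pool.
full_ds = ConcatDataset([train_ds, val_ds, test_ds])

generator = torch.Generator().manual_seed(42)  # reproducible shuffle
new_train, new_val, new_test = random_split(
    full_ds, [5000, 1000, 1000], generator=generator
)
```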

This may or may not be appropriate for your use case, but I would now
expect you to be able to train using your new 5k training dataset without
overfitting setting in before you can perform reasonably well on your
validation dataset. Even if your training dataset is effectively “smaller”
than 5k because the images are “correlated,” your validation images will
be equally well “correlated” with your training images, making it easier for
your model to perform well on your validation dataset.

The issues of data imbalance and overfitting are pretty much independent
of one another. Using pos_weight makes sense for the problem you’ve
described (so keep using it), but it’s unlikely to help with overfitting.

Dropout is a reasonable approach to address overfitting, so consider
adding it back to your model as you continue your experiments. You might
also try data augmentation (but if your images are already highly correlated,
this might not help a lot) and weight decay.
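
For concreteness, here is a minimal sketch of adding weight decay and
some simple augmentation (the lr, weight_decay, and jitter values are
illustrative only, and model is assumed to be your network):

```python
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

# Weight decay is passed to the optimizer; 1e-4 is an illustrative value.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Photometric augmentation can be applied to the image alone ...
color_augment = T.ColorJitter(brightness=0.2, contrast=0.2)

# ... but geometric augmentation must be applied identically to the
# image and its mask, so handle it jointly in your Dataset:
def flip_augment(image, mask):
    if torch.rand(1).item() < 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    return image, mask
```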

If, after shuffling your whole combined 7k dataset and trying the above,
you’re still having trouble with overfitting, you will probably have to reduce
the “capacity” of your model by having fewer and / or “narrower” (that is,
smaller values for out_channels and out_features) layers.
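
For example (a toy sketch; the channel counts are arbitrary):

```python
import torch.nn as nn

# A higher-capacity convolutional block ...
wide_block = nn.Sequential(
    nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, padding=1),
    nn.ReLU(),
)

# ... and a "narrower" variant with a smaller out_channels, hence
# fewer parameters and less capacity to memorize training details.
narrow_block = nn.Sequential(
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
    nn.ReLU(),
)
```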

At this point, I wouldn’t worry about Dice coefficient. Get your overfitting
problem solved first, and only then – if you clearly are having problems
with data imbalance – add the complication of incorporating something
like Dice coefficient.

Best.

K. Frank

Thanks for the quick replies, let me give more info about the dataset.

The training data are always from the same scene (same place), with different objects and sometimes different lighting conditions. Anyway, I consider this ‘highly correlated,’ as you say.

The validation and test sets each contain 3 different scenes, but the number of images is smaller.

I’m afraid that if I shuffle them I may have a data leak. I manually combined one validation scene and one test scene with the training data, but that did not work out either.

Do you think the SMOTE technique can help? As far as I know it is used for tabular data, but I saw an implementation for binary segmentation too.

In conclusion, I think the only solution would be to try to find more image data that is not highly correlated.

About making the model narrower: does that mean simplifying the model? I thought the opposite and tried a deeper feature extractor (ResNet101 instead of ResNet50), but it overfitted even faster.

Hi Mahammad!

Okay, as I understand it, you want to train on data from scene A.

But you want your model to work on data from scenes B, C, D and E, F, G.

Yes, given your use case, shuffling them together could be considered a kind of data leak.

This is an issue all the time in the real world. Let’s say you train a
self-driving-vehicle model on dirt road A, city street B, and highway C.
You really can’t expect such a model to work on a bunch of other roads.

But if you train on a lot of different dirt roads and city streets and highways,
your model might work on various roads it wasn’t trained on.

Clear your mind of thoughts of imbalanced data.

Your problem is that your training data is not sufficiently representative
of the data you want to apply your model to (hence leading to overfitting
issues). (Note, techniques that reduce overfitting can help your model
“generalize” to data that differs somewhat in character from the data it
was trained on. But you can only push this so far – at some point you
have to train on data that is reasonably representative of the data you
actually want to apply it to.)

Issues of data imbalance are independent of your main issue, and my
working assumption is that pos_weight will be enough to address the
data imbalance you have.

Yes. Going back to your notion of “scenes,” let’s say that your goal is to
have your model work on data from scenes B through G. But when you
train on data from scene A, your model turns out to learn about the specific
details
of scene A, rather than the general character of your actual use
case. (This effect is indeed what we call overfitting.)

It’s fair to consider training on data from scenes B through G a data leak.

So what you need to do is train on data from other scenes, say H, I, J, K, …,
that are representative of (but not the same as) the scenes you want to
apply your model to.

You’re not “cheating” (data leak) because your model isn’t learning any details
of scenes B through G. Instead, you’re not learning the specific details of
scene A (because your model is being forced to learn the shared general
character of a bunch of different scenes). As long as the “details” of your
training scenes are representative of (but not the same as) the “details” of
the scenes to which you will apply your model, you should be able to train
your model to work for your use case.

Yes – basically fewer parameters.

This was to be expected. The greater the “capacity” of your model (more
parameters, roughly speaking), the more capacity it will have to learn the
irrelevant details of your training data and overfit.

As a rule of thumb, if you can’t fit (including overfit) your training data,
you might try a model with greater capacity, but if you overfit too easily,
you might try a model with lower capacity.

One last note: You can sometimes address overfitting by fine-tuning a
pre-trained version of your model, rather than training the model from
scratch. But this depends on having an appropriately pre-trained model.
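
As a hedged sketch of that approach (fcn_resnet50 here is just one
example of a pre-trained segmentation model, and which layers to freeze
is up to you):

```python
import torch
import torchvision

# Load a pre-trained segmentation model and replace its head with a
# single output channel for the binary foreground class.
model = torchvision.models.segmentation.fcn_resnet50(weights="DEFAULT")
model.classifier[4] = torch.nn.Conv2d(512, 1, kernel_size=1)

# Optionally freeze the pre-trained backbone so that only the new head
# (far fewer parameters) is trained, which can reduce overfitting.
for param in model.backbone.parameters():
    param.requires_grad = False
```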

Best.

K. Frank


Thank you sir, I learnt a lot from your answers.