(Weighted) Dice Loss only predicts 3 out of 9 classes

Hello everyone, I’m kinda new to ML and CV and I’ve been training a semantic segmentation model for my master’s thesis. With my Dice loss, the model stagnates after 20ish epochs, which does not happen with CrossEntropyLoss. Also, when testing the model, it only ever predicts the first 3 out of 9 classes. Two of those classes are predominant in my dataset, while the third is actually relatively rare.

This is my code for weighted dice loss:

class ClassAverageDiceLoss(nn.Module):

    def __init__(self, num_classes, softmax_dim=None):
        super().__init__()
        self.num_classes=torch.tensor(num_classes)
        self.softmax_dim=softmax_dim

    def forward(self, logits, targets, reduction='mean', smooth=1e-5):
        probabilities=logits
        if self.softmax_dim is not None:
            probabilities = nn.Softmax(dim=self.softmax_dim)(logits)
        #end if
        
        targets_one_hot=torch.nn.functional.one_hot(targets, num_classes=self.num_classes)

        # Convert from NHWC to NCHW
        targets_one_hot = targets_one_hot.permute(0, 3, 1, 2)

        dice_coeff_individual_array=torch.zeros(self.num_classes, dtype=torch.float32).to(DEVICE)

        # calculate class-specific dice loss/score
        for n_class in range(self.num_classes):

            targets_one_hot_class_specific=targets_one_hot[:,n_class,:,:]
            probabilities_class_specific=probabilities[:,n_class,:,:]

            intersection_class_specific = (targets_one_hot_class_specific * probabilities_class_specific).sum()

            mod_a = intersection_class_specific.sum()
                # .numel() and -log() in combination, due to the loss function always trying to minimize something
            mod_b = targets_one_hot_class_specific.numel()

            dice_coeff_individual = 2. * intersection_class_specific / (mod_a + mod_b + smooth)

            dice_coeff_individual_array[n_class] = dice_coeff_individual
        
        dice_coeff_average=dice_coeff_individual_array.sum()/(self.num_classes)

        dice_loss_average=-dice_coeff_average.log()

        return dice_loss_average

I thought that dice loss was good at handling imbalanced datasets which is why I was trying that out. If somebody could check my code for the weighted dice loss, I would be really thankful!

Here is a graph of my training vs. validation loss:

Hi Tony!

(Note, I haven’t looked at your Dice-loss code.)

There is a legitimate argument that Dice loss should help with imbalanced datasets.

But CrossEntropyLoss is a very good loss for training classification models (and
semantic segmentation is a kind of classification in that you are classifying pixels).
My intuition is that the benefits of CrossEntropyLoss – once you incorporate class
weights to address the issue of class imbalance – outweigh the benefits of Dice loss.

What is your goal here? Why did you switch to Dice loss from CrossEntropyLoss
(when CrossEntropyLoss seemed to be working better)? Is your project to compare
how well the two loss functions work or is your primary goal to train an effective
semantic-segmentation model?

My recommendation would be to use CrossEntropyLoss with class weights. If that
doesn’t work well enough and you think that Dice loss might help, I would add Dice
loss (with a tunable hyperparameter weight) to CrossEntropyLoss. That way you
get the generally good convergence properties of CrossEntropyLoss together with
any potential class-imbalance benefit from Dice loss.
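
As a rough illustration of "CrossEntropyLoss with class weights", here is a minimal sketch. The helper `inverse_frequency_weights` and the inverse-frequency weighting scheme are assumptions made for this example (other weighting schemes are equally plausible), and the random tensors only stand in for real model output and labels.

```python
import torch
import torch.nn as nn

def inverse_frequency_weights(targets, num_classes, eps=1.0):
    # Hypothetical helper: rarer classes get larger weights.
    # targets: integer class labels of shape (N, H, W)
    counts = torch.bincount(targets.flatten(), minlength=num_classes).float()
    # eps avoids division by zero for classes that do not occur at all
    return counts.sum() / (num_classes * (counts + eps))

torch.manual_seed(0)
num_classes = 9
logits = torch.randn(2, num_classes, 8, 8)            # (N, C, H, W), stand-in for model output
targets = torch.randint(0, num_classes, (2, 8, 8))    # (N, H, W), stand-in for labels

class_weights = inverse_frequency_weights(targets, num_classes)
criterion = nn.CrossEntropyLoss(weight=class_weights)
loss = criterion(logits, targets)
```

In practice the weights would be computed once over the whole training set rather than per batch.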

Best.

K. Frank

Hi K. Frank, thanks for the quick response. I’m trying out dice loss, because my final evaluation metric is dice score and I thought that directly training with dice loss would have a positive effect on the entire procedure.

What is your goal here? Why did you switch to Dice loss from CrossEntropyLoss
(when CrossEntropyLoss seemed to be working better)?
→ Currently, I can’t say if CrossEntropyLoss works better, because I have the feeling that my DiceLoss isn’t really working as intended.

Is your project to compare how well the two loss functions work or is your primary goal to train an effective semantic-segmentation model?
→ I just want to train an effective segmentation model for my application. But since I’m also still learning, it would be interesting for me to utilize different loss functions and see how they compare.

I would add Diceloss (with a tunable hyperparameter weight) to CrossEntropyLoss.
→ Okay, so a combination of Diceloss and CrossEntropyLoss? Would you just sum up the loss? And if so, would you still apply class weights to CrossEntropyLoss?

Thanks in advance!

Best regards, Tony

Hi Tony!

Training directly on your evaluation metric is not an unreasonable way of looking
at things and it often makes sense.
However, it can also make sense to use a loss function for training that, while
related to, is not the same as the final metric you use to decide how good your
model is. It’s perfectly possible for your training loss to have, so to speak, better
training properties than your evaluation metric does.

As a simple example, if your final metric isn’t differentiable, you can’t train with
it (using gradient-descent optimizers).

Consider ordinary classification (as distinct from semantic-segmentation pixel
classification). A common metric is the accuracy – the fraction of samples that
are correctly classified. As this involves discrete counting, it’s not differentiable.
But it is also possible to define a “probabilistic” accuracy that is differentiable.
Nonetheless, a lot of experience shows that cross entropy has better “trainability”
properties than such a probabilistic accuracy, so cross entropy is generally the
go-to loss for training classifiers.
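
To make the distinction concrete, here is a small sketch. The definition of "probabilistic" accuracy used here (the mean predicted probability of the correct class) is one possible choice assumed for illustration, not the only one.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(16, 9, requires_grad=True)   # (batch, classes)
labels = torch.randint(0, 9, (16,))

# Hard accuracy: argmax is piecewise constant, so no gradient flows through it.
hard_acc = (logits.argmax(dim=1) == labels).float().mean()

# "Probabilistic" accuracy (assumed definition): mean probability of the correct class.
probs = F.softmax(logits, dim=1)
correct_class_probs = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
soft_acc = correct_class_probs.mean()

# Cross entropy: mean negative log-probability of the correct class.
ce = F.cross_entropy(logits, labels)

print(hard_acc.requires_grad)  # False: cannot be used as a training loss
print(soft_acc.requires_grad)  # True: differentiable, so 1 - soft_acc could be minimized
```

Both `soft_acc` and `ce` are built from the same per-sample probabilities; they differ only in whether the log is applied before averaging.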

It does appear that your Dice-loss code has some issues – see below.

Comparing different loss functions is always worth doing. I would recommend
comparing CrossEntropyLoss (with class weights), some version of your Dice loss,
and a combination of the two.

Yes, I would just sum the two losses (but with an adjustable weight). Something like:

total_loss = cross_entropy_loss + dice_weight * dice_loss
total_loss.backward()

And yes, absolutely keep the class weights on CrossEntropyLoss (assuming that there
is a significant imbalance in the pixel classes, which there likely is).
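
A minimal sketch of one training step with the combined loss might look like the following. The helper `soft_dice_loss`, the one-layer stand-in model, the uniform placeholder `class_weights`, and the value of `dice_weight` are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def soft_dice_loss(logits, targets, num_classes, smooth=1e-5):
    # Hypothetical helper: 1 - class-averaged soft Dice coefficient.
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    inter = (probs * one_hot).sum((0, 2, 3))
    denom = probs.sum((0, 2, 3)) + one_hot.sum((0, 2, 3))
    return 1.0 - ((2.0 * inter + smooth) / (denom + smooth)).mean()

torch.manual_seed(0)
num_classes = 9
model = nn.Conv2d(3, num_classes, kernel_size=1)   # stand-in for a real segmentation network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
class_weights = torch.ones(num_classes)            # placeholder; derive real weights from the data
dice_weight = 0.5                                  # tunable hyperparameter

images = torch.randn(2, 3, 8, 8)
targets = torch.randint(0, num_classes, (2, 8, 8))

optimizer.zero_grad()
logits = model(images)
cross_entropy_loss = F.cross_entropy(logits, targets, weight=class_weights)
dice_loss = soft_dice_loss(logits, targets, num_classes)
total_loss = cross_entropy_loss + dice_weight * dice_loss
total_loss.backward()
optimizer.step()
```

Setting `dice_weight = 0.0` recovers plain weighted cross entropy, which makes it easy to compare the variants within one training script.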

Some comments about the Dice-loss code from your first post:

The .sum() in mod_a = intersection_class_specific.sum() looks unnecessary or like
a typo. intersection_class_specific has already been summed over pixels, so this
second .sum() does nothing.

The line mod_b = targets_one_hot_class_specific.numel() is wrong. As written,
mod_b is just the number of pixels and is a constant, independent of n_class.
You want something like:

mod_b = targets_one_hot_class_specific.sum()

which counts the number of pixels that are in class n_class.

I do not think that taking the log() of the Dice coefficient is unreasonable. However,
as originally proposed, Dice loss involved maximizing (a probabilistic version of) the
Dice coefficient (which is often structured as minimizing 1 - Dice_coefficient or
just -Dice_coefficient), so you might try that.
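
Putting these comments together, a class-averaged soft Dice loss could be sketched as below. This is one possible version, not the code from the first post. Note one assumption that goes beyond the points above: the denominator uses the sum of predicted probabilities for each class (the usual soft-Dice form) rather than reusing the intersection as mod_a. The class name `SoftDiceLoss` is made up for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftDiceLoss(nn.Module):
    """Class-averaged soft Dice loss, returned as 1 - Dice coefficient (a sketch)."""

    def __init__(self, num_classes, smooth=1e-5):
        super().__init__()
        self.num_classes = num_classes
        self.smooth = smooth

    def forward(self, logits, targets):
        # logits: (N, C, H, W) raw scores; targets: (N, H, W) integer labels
        probs = F.softmax(logits, dim=1)
        one_hot = F.one_hot(targets, num_classes=self.num_classes)
        one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)  # NHWC -> NCHW

        dims = (0, 2, 3)  # sum over batch and pixels, keep the class axis
        intersection = (probs * one_hot).sum(dims)
        pred_sum = probs.sum(dims)       # assumption: standard soft-Dice term
        target_sum = one_hot.sum(dims)   # number of pixels in each class

        dice = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth)
        return 1.0 - dice.mean()

# Usage with random stand-in data
torch.manual_seed(0)
loss_fn = SoftDiceLoss(num_classes=9)
targets = torch.randint(0, 9, (2, 8, 8))
random_loss = loss_fn(torch.randn(2, 9, 8, 8), targets)
```

Because the sums run over the device the inputs already live on, this version does not need a global DEVICE variable, and the per-class loop is replaced by a single reduction.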

Best.

K. Frank

Hi K. Frank, thanks again for taking the time to respond to my question and for the insights.

total_loss = cross_entropy_loss + dice_weight * dice_loss
total_loss.backward()

I will definitely try this out and report whether this improved the accuracy or not.

This looks unnecessary or like a typo. intersection_class_specific has already
been summed over pixels, so the .sum() in the expression for mod_a does nothing.

Thanks for pointing that out!

  mod_b = targets_one_hot_class_specific.numel()

Regarding this, I copied the basic formula from another post that I saw. I realized that this wasn’t the actual formula for DiceLoss but thought that since the -log was applied to the function, it would be beneficial to take a relatively small number, since the -log makes it bigger.

I do not think that taking the log() of the Dice coefficient is unreasonable. However,
as originally proposed, Dice loss involved maximizing (a probabilistic version of) the
Dice coefficient (which is often structured as minimizing 1 - Dice_coefficient or
just -Dice_coefficient), so you might try that.

Alright, thanks, I’ll try that out too. Here, I also thought that since CrossEntropy utilizes a log function, it would be easier during training to minimize that value (compared to a linear function).

Best regards, Tony!

Hi Tony!

I think it might be true that the log makes the loss easier to minimize (but I
don’t know that it is).

I don’t have a sound theoretical explanation for why CrossEntropyLoss works
as well as it does.

However, when the predicted probability for the correct class goes to zero,
CrossEntropyLoss has a logarithmic divergence. My intuition has always been
that this logarithmic divergence is very helpful for training. (On the other hand,
I’ve added logarithmic divergences to other loss functions “by hand,” and doing
so hasn’t helped as much as I might have hoped.)
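
As a small numerical illustration of that divergence, the sketch below compares the gradient of a -log(p) term with that of a linear 1 - p term as the probability p of the correct class shrinks. It looks only at gradients with respect to the probability itself, not the logits, so it is a simplified picture rather than an explanation of training behavior.

```python
import torch

# Probability assigned to the correct class, from moderate to very small
p = torch.tensor([0.5, 0.1, 0.01, 0.001], requires_grad=True)

log_loss = -torch.log(p)   # cross-entropy-style term
linear_loss = 1.0 - p      # linear term, as in 1 - Dice_coefficient

(grad_log,) = torch.autograd.grad(log_loss.sum(), p)
(grad_lin,) = torch.autograd.grad(linear_loss.sum(), p)

print(grad_log)  # equals -1/p, so its magnitude grows without bound as p -> 0
print(grad_lin)  # constant -1, regardless of how wrong the prediction is
```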

Best.

K. Frank

Hi Tony,

Did CrossEntropyLoss work better in the end?