 # Dice Loss + Cross Entropy

Hello everyone,
I don't know if this is the right place to ask, but I'll ask anyway.
I am working on a multi-class semantic segmentation problem, and I want to use a loss function that combines both dice loss and cross entropy loss. How should I do this?
I don't think simply adding the dice score and cross entropy makes sense: the dice score is a small value between 0 and 1, but cross entropy can take very large values.
So is there some kind of normalization that should be applied to the CE loss to bring it to the same scale as the dice loss?

Hello, I have exactly the same problem. Have you found anything interesting?

For the moment I am trying to normalize the CrossEntropyLoss by averaging it (`reduction='mean'`, which is the default). The problem is, as you said, that cross entropy can still take values bigger than 1.

I am actually trying Loss = CE - log(dice_score), where dice_score is the dice coefficient (as opposed to the dice loss, which is basically dice_loss = 1 - dice_score). I will wait for the results, but any hints or help would be really appreciated.
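For reference, here is a minimal sketch of the CE - log(dice_score) idea, assuming logits of shape (N, C, H, W), integer labels of shape (N, H, W), and a soft multi-class dice coefficient averaged over classes. The helper names `soft_dice_score` and `ce_minus_log_dice` and the `eps` smoothing term are my own choices, not from any library. Since the dice coefficient is at most 1, the -log term is never negative and grows quickly as the overlap gets poor.

```python
import torch
import torch.nn.functional as F

def soft_dice_score(logits, target, eps=1e-6):
    """Soft Dice coefficient, averaged over classes (hypothetical helper).

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer class labels.
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    # One-hot encode the labels to (N, C, H, W) so they line up with probs.
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and spatial dims, keep one value per class
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice_per_class = (2 * intersection + eps) / (cardinality + eps)
    return dice_per_class.mean()

def ce_minus_log_dice(logits, target):
    """Cross entropy plus -log(dice); both terms are >= 0."""
    ce = F.cross_entropy(logits, target)
    return ce - torch.log(soft_dice_score(logits, target))
```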

Hi, sorry for the late reply. I think what could be done for the cross entropy loss is this:
take the average of the negative log of the logits over one class, and then average this value over all the classes. Doing this would give you a value on a similar scale to the dice loss, so you could simply add this CE value to it. Do you see what I mean? If not, let me know. However, when I try this method my loss keeps becoming NaN after a few epochs (it is highly unpredictable when that happens). Did you get decent results with what you mentioned, i.e. taking the log of the dice score?
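One possible reading of this per-class averaging, sketched under assumptions: the per-pixel term is the negative log-probability of the true class (taken from `F.log_softmax` rather than from raw logits), it is averaged over the pixels of each class present in the batch, and those class means are then averaged. The function name `per_class_mean_ce` is hypothetical. Note that each class mean is still unbounded above, so this rebalances the classes but does not by itself force the value into [0, 1].

```python
import torch
import torch.nn.functional as F

def per_class_mean_ce(logits, target):
    """Mean over classes of the mean per-pixel CE within each class (hypothetical).

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer class labels.
    """
    log_probs = F.log_softmax(logits, dim=1)  # numerically stable log-probabilities
    # Negative log-probability of the true class at every pixel: (N, H, W)
    nll = -log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
    per_class = []
    for c in range(logits.shape[1]):
        mask = target == c
        if mask.any():  # skip classes that do not appear in this batch
            per_class.append(nll[mask].mean())
    return torch.stack(per_class).mean()
```

When every class covers the same number of pixels, this reduces to the ordinary mean cross entropy; on imbalanced batches it gives rare classes as much weight as frequent ones.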

I am not sure I have understood your point. CE works with probabilities, so if a pixel has a very small probability of belonging to its true class (let's say 1e-6), its loss will be -log(p), which in this case is ~14, far outside the [0, 1] range of the dice loss. I think nn.CrossEntropyLoss() in PyTorch normalizes the values by default (by taking the mean over pixels). But in the end we still don't have the same range…
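A quick numeric check of this point, using a made-up 2-class, 2x2 example: the per-pixel term for p = 1e-6 is about 13.8, and averaging over pixels (the default `reduction='mean'`) does not cap the result at 1 when many pixels are confidently wrong.

```python
import math
import torch
import torch.nn.functional as F

# Per-pixel CE when the true class gets probability 1e-6.
per_pixel = -math.log(1e-6)

# A 2-class, 2x2 "image" in which every pixel is confidently wrong:
# the true class is 0 everywhere, but class 1 gets a logit 20 units higher.
target = torch.zeros(1, 2, 2, dtype=torch.long)
logits = torch.zeros(1, 2, 2, 2)
logits[:, 1] = 20.0
mean_ce = F.cross_entropy(logits, target)  # reduction="mean" is the default

print(per_pixel, mean_ce.item())
```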

But since the values I obtain are close to [0, 1], I simply add the two losses, and the result is "good" in terms of mIoU. However, I truly believe there is a "smarter" way to combine them, and I am trying to find it… Maybe something like DiceLoss + 0.7 * CE would work to "rescale" the CE.
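A minimal sketch of the "DiceLoss + 0.7 * CE" combination as a single PyTorch module, assuming (N, C, H, W) logits, (N, H, W) integer labels, and a soft dice loss averaged over classes. The class name `DiceCELoss` and its `ce_weight` and `eps` parameters are illustrative, and 0.7 is only the value floated here, not a tuned setting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceCELoss(nn.Module):
    """Soft multi-class Dice loss plus a weighted cross entropy term (hypothetical)."""

    def __init__(self, ce_weight=0.7, eps=1e-6):
        super().__init__()
        self.ce_weight = ce_weight  # tunable trade-off between the two terms
        self.eps = eps              # smoothing to avoid division by zero

    def forward(self, logits, target):
        num_classes = logits.shape[1]
        probs = F.softmax(logits, dim=1)
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
        dims = (0, 2, 3)  # reduce over batch and spatial dims, one score per class
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.eps) / (cardinality + self.eps)
        dice_loss = 1 - dice.mean()              # in [0, 1]
        ce = F.cross_entropy(logits, target)     # expects raw logits
        return dice_loss + self.ce_weight * ce
```

Usage would look like `criterion = DiceCELoss(ce_weight=0.7)` followed by `loss = criterion(logits, target)` and `loss.backward()`. Keeping the weight as a constructor argument makes it easy to sweep a few values and compare mIoU on a validation set.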

On the other hand, using the log of the dice score didn't give great results, but I will try it again when my dataset is bigger.

Finally, I don't understand how you got NaN values from your loss. Are you sure you reset the gradients properly?

Yes, sorry, I didn't mean to say logits; I meant softmax probabilities. Yeah, I've seen the documentation of nn.CrossEntropyLoss, so I'll try using that instead of writing my own function as I have been doing until now. Also, can you please tell me whether the input to nn.CrossEntropyLoss should be softmax probabilities or raw logits? It isn't clear to me after reading the documentation.
And regarding the gradients becoming NaN, I'm just calling zero_grad()… is there anything else? I don't know.

You pass in 'raw' logits. The criterion itself applies the (log-)softmax normalization.
See nn.CrossEntropyLoss: “This criterion combines `nn.LogSoftmax()` and `nn.NLLLoss()` in one single class.”
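A small check of that equivalence on segmentation-shaped tensors (random data, shapes chosen arbitrarily): `nn.CrossEntropyLoss` on raw logits matches `nn.NLLLoss` on `F.log_softmax` of the same logits. It also shows why, if you use `NLLLoss` directly, you have to apply `log_softmax` yourself; feeding softmax probabilities into `CrossEntropyLoss` would effectively apply softmax twice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)          # (N, C, H, W) raw network outputs
target = torch.randint(0, 4, (2, 8, 8))   # (N, H, W) class indices

ce = nn.CrossEntropyLoss()(logits, target)                      # takes raw logits
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)        # takes log-probabilities
print(torch.allclose(ce, nll))  # True
```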

Yes, as @andreaskoepf said, nn.CrossEntropyLoss combines both LogSoftmax and NLLLoss, which is why I used NLL instead of CE.

Are you doing something like optimizer.zero_grad() at the beginning of each epoch?

Yes, I do that at the beginning of every iteration, not every epoch. Anyway, I'll use the built-in PyTorch CE function and see whether I still get NaN. Thanks
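The thread doesn't pin down where the NaN came from, but one common cause when CE is written by hand from softmax probabilities is taking the log of a probability that has underflowed to 0. A minimal sketch with one made-up, extremely confident wrong prediction shows the difference between that approach and the built-in loss, which works on log-probabilities:

```python
import torch
import torch.nn.functional as F

# One pixel, two classes: the model is extremely confident in the wrong class.
target = torch.tensor([0])

# Hand-rolled CE: softmax first, then log. exp(-200) underflows to 0 in float32,
# so log(0) = -inf, the loss is inf, and the backward pass yields NaN gradients.
logits_a = torch.tensor([[0.0, 200.0]], requires_grad=True)
probs = F.softmax(logits_a, dim=1)
naive_loss = -torch.log(probs[0, target.item()])
naive_loss.backward()

# Built-in CE uses log_softmax internally, so the loss and gradients stay finite.
logits_b = torch.tensor([[0.0, 200.0]], requires_grad=True)
stable_loss = F.cross_entropy(logits_b, target)
stable_loss.backward()

print(naive_loss.item(), logits_a.grad)
print(stable_loss.item(), logits_b.grad)
```

Clamping the probabilities (e.g. `probs.clamp_min(1e-7)`) before the log also avoids the infinity, but calling `F.log_softmax` or `F.cross_entropy` directly is the more standard fix.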

Okay sure, let me know if it works.

It's a challenging problem, but you can reason about it this way: each loss prioritizes different features in your image. It is up to you to decide which features are most important and then weight the losses so that the outcome is acceptable. For example, dice loss puts more emphasis on under-represented classes, so if you weight it more, your output will be more accurate/sensitive toward those classes. CE prioritizes overall pixel-wise accuracy, so some classes might suffer if they don't have enough representation to influence CE. When you add the two together, you can play around with weighting the contributions of CE vs. Dice in your function until the result is acceptable.
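A toy illustration of this trade-off, under stated assumptions: a 10x10 image with only 2 foreground pixels, a hypothetical "lazy" model that predicts background everywhere, and a soft dice loss averaged over classes (other dice variants weight classes differently). Mean CE is dominated by the 98 easy background pixels and stays small, while the class-averaged dice loss gives the missed foreground class as much weight as the background and stays large.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(logits, target, eps=1e-6):
    """1 - soft Dice coefficient averaged over classes (hypothetical helper)."""
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    return 1 - ((2 * intersection + eps) / (cardinality + eps)).mean()

# A 10x10 image with only 2 foreground pixels (class 1) and 98 background pixels.
target = torch.zeros(1, 10, 10, dtype=torch.long)
target[0, 0, :2] = 1

# A "lazy" model that predicts background everywhere with high confidence.
logits = torch.zeros(1, 2, 10, 10)
logits[:, 0] = 5.0

ce = F.cross_entropy(logits, target)
dice = soft_dice_loss(logits, target)
print(f"CE = {ce.item():.3f}, Dice loss = {dice.item():.3f}")
```

With these numbers CE comes out around 0.1 and the dice loss around 0.5, so in a combined loss it is mostly the dice term that penalizes ignoring the rare class.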

There are myriad different loss functions out there. Here are the ones I've experimented with so far. I've had good luck with `BCEDicePenalizeBorderLoss`, but it uses BCE rather than CE, so I'm not sure it will work for you.
