Dice Loss with custom penalties

Hi all,

I am wading through this CV problem, and my results are getting better.

The challenge is that my images are imbalanced: the background and one other class dominate. Cross Entropy was a wash, but Dice Loss showed some improvement in picking up the less prevalent class. I think I need an added penalty for getting the less prevalent class wrong. That class will always be less prevalent, so I would just need to get the weights right.

I have been using the very handy segmentation models package, but it does not provide a way of adding custom class weights. So I want to modify its Dice Loss to add them.

I am confused about where I should put this in the code:

from typing import Optional, List

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from ._functional import soft_dice_score, to_tensor
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE

__all__ = ["DiceLoss"]


class DiceLoss(_Loss):
    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
    ):
        """Dice loss for image segmentation task.
        It supports binary, multiclass and multilabel cases

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            classes:  List of classes that contribute in loss computation. By default, all channels are included.
            log_loss: If True, loss computed as `- log(dice_coeff)`, otherwise `1 - dice_coeff`
            from_logits: If True, assumes input is raw logits
            smooth: Smoothness constant for dice coefficient (a)
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            eps: A small epsilon for numerical stability to avoid zero division error
                (denominator will be always greater or equal to eps)

        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super(DiceLoss, self).__init__()
        self.mode = mode
        if classes is not None:
            assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
            classes = to_tensor(classes, dtype=torch.long)

        self.classes = classes
        self.from_logits = from_logits
        self.smooth = smooth
        self.eps = eps
        self.log_loss = log_loss
        self.ignore_index = ignore_index

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == MULTICLASS_MODE:
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask.unsqueeze(1)

                y_true = F.one_hot((y_true * mask).to(torch.long), num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # N, C, H*W
            else:
                y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1)  # N, C, H*W

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = self.compute_score(y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims)

        if self.log_loss:
            loss = -torch.log(scores.clamp_min(self.eps))
        else:
            loss = 1.0 - scores

        # Dice loss is undefined for empty classes
        # So we zero contribution of channel that does not have true pixels
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss

        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            loss = loss[self.classes]

        return self.aggregate_loss(loss)

    def aggregate_loss(self, loss):
        return loss.mean()

    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_dice_score(output, target, smooth, eps, dims)

Where can I add my edit? I see a reference to

from ._functional import soft_dice_score

but I cannot find that file anywhere. Ideally I would like to avoid writing something from scratch. I suspect the edit belongs where the Dice score is actually calculated, but I cannot find that code.
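For reference, a soft Dice score of the kind compute_score() delegates to is usually defined as 2|X ∩ Y| / (|X| + |Y|), computed on predicted probabilities rather than hard labels. Below is a minimal sketch under that assumption; the name my_soft_dice_score is hypothetical, and the package's own helper may differ in details, so check its _functional.py before relying on this. With dims=(0, 2), as in forward() above, it returns one score per class, which is where per-class penalties can be applied.

```python
import torch


def my_soft_dice_score(output: torch.Tensor, target: torch.Tensor,
                       smooth: float = 0.0, eps: float = 1e-7, dims=None) -> torch.Tensor:
    """Hypothetical re-implementation of a soft Dice score.

    output: predicted probabilities, e.g. shape (N, C, H*W)
    target: one-hot ground truth with the same shape as output
    dims:   dimensions to sum over; (0, 2) gives one score per class
    """
    assert output.size() == target.size()
    if dims is None:
        dims = tuple(range(output.dim()))  # sum over everything -> a single score
    # Soft intersection |X ∩ Y| and soft cardinality |X| + |Y|
    intersection = torch.sum(output * target, dim=dims)
    cardinality = torch.sum(output + target, dim=dims)
    # Dice = 2|X ∩ Y| / (|X| + |Y|); eps keeps the denominator away from zero
    return (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
```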

Hi Matthew!

Please let me offer some suggestions without answering your specific
question.

I do realize that Dice Loss is often recommended for segmentation
problems with unbalanced data, but let me suggest instead using
CrossEntropyLoss with class weights (CrossEntropyLoss's weight
constructor argument).
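A minimal sketch of weighted CrossEntropyLoss for segmentation, assuming raw logits of shape (N, C, H, W) and integer targets of shape (N, H, W). The three weight values are made up for illustration (background, dominant class, rare class); a larger weight makes mistakes on that class cost more.

```python
import torch
import torch.nn as nn

num_classes = 3
# Hypothetical weights: background, dominant class, rare class
class_weights = torch.tensor([0.2, 0.3, 2.5])

criterion = nn.CrossEntropyLoss(weight=class_weights)

# Segmentation shapes: logits (N, C, H, W), target class indices (N, H, W)
logits = torch.randn(2, num_classes, 4, 4, requires_grad=True)
target = torch.randint(0, num_classes, (2, 4, 4))

loss = criterion(logits, target)  # weighted mean over all pixels
loss.backward()
```

With the default reduction="mean", PyTorch divides by the sum of the weights of the target pixels, so the overall scale of the weights does not matter, only their ratios.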

My advice is to start with (weighted) CrossEntropyLoss, and if
that doesn’t seem to be doing well enough, try adding Dice Loss
to CrossEntropyLoss as a further contribution to the total loss.
My view is that doing so is likely to work better than using Dice Loss
in isolation (and that weighted CrossEntropyLoss is likely to work
well without adding Dice Loss).
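One way to add Dice Loss to CrossEntropyLoss is to sum the two terms. The sketch below is an assumption, not the package's API: CEDiceLoss and dice_weight are hypothetical names, and the Dice term is written out by hand (mirroring the per-class average and empty-class masking in the DiceLoss code above) so it does not depend on the package.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CEDiceLoss(nn.Module):
    """Hypothetical combined loss: weighted cross entropy + dice_weight * soft Dice loss."""

    def __init__(self, class_weights=None, dice_weight: float = 1.0, eps: float = 1e-7):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(weight=class_weights)
        self.dice_weight = dice_weight
        self.eps = eps

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits: (N, C, H, W) raw scores; target: (N, H, W) class indices
        ce_loss = self.ce(logits, target)

        num_classes = logits.size(1)
        probs = logits.softmax(dim=1)                                 # (N, C, H, W)
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2)  # (N, C, H, W)
        one_hot = one_hot.type_as(probs)

        dims = (0, 2, 3)  # sum over batch and pixels -> one score per class
        intersection = (probs * one_hot).sum(dims)
        cardinality = (probs + one_hot).sum(dims)
        dice = 2.0 * intersection / cardinality.clamp_min(self.eps)

        # Zero the contribution of classes with no true pixels, then average over classes
        present = (one_hot.sum(dims) > 0).type_as(dice)
        dice_loss = ((1.0 - dice) * present).mean()

        return ce_loss + self.dice_weight * dice_loss
```

dice_weight is a knob for balancing the two terms; 1.0 is only a starting point.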

My only comment here is that if you have access to the code at a level
where you could add custom weights to a Dice Loss, you should also be
able to swap in CrossEntropyLoss (with its built-in weight argument).

Best.

K. Frank


Thanks for the tip! I will start with the custom weights and may add Dice later. I will keep you posted!!!

@KFrank

Real quick question: if I wanted to use both, how would I do that?

Hi @NearsightedCV,
From what I know, Dice loss for multiple classes is the average of the per-class Dice losses, so it already balances the data in a way.
But if you want, I think you can change how they are averaged.

The variable loss in forward() is a vector with shape (#Classes,), because the scores are summed over dims (0, 2), i.e. over the batch and the pixels. You can multiply it by a weight vector before it is averaged.
But I also think you can get a good enough result with weighted cross entropy.
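In the DiceLoss code above, forward() ends with return self.aggregate_loss(loss), and aggregate_loss just returns loss.mean(). So, assuming you subclass DiceLoss, one low-effort place for class weights is an override of aggregate_loss that calls something like the hypothetical helper below. If classes is set, loss has already been indexed by self.classes, so the weights must line up with those selected classes.

```python
import torch


def weighted_class_mean(loss: torch.Tensor, class_weights: torch.Tensor) -> torch.Tensor:
    """Weighted average of a per-class loss vector (hypothetical helper).

    loss:          shape (C,), e.g. 1 - dice for each class
    class_weights: shape (C,), user-chosen weights
    """
    class_weights = class_weights.to(loss.device, loss.dtype)
    # Normalising by the weight sum keeps the result on the same scale as loss.mean()
    return (loss * class_weights).sum() / class_weights.sum()


# e.g. in a subclass of DiceLoss (untested sketch):
#     def aggregate_loss(self, loss):
#         return weighted_class_mean(loss, self.class_weights)
```

Multiplying by the weights and then taking the plain mean also works; it only rescales the loss by a constant, which in practice acts like a change of learning rate.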

@mMagmer I was leaning towards Dice because just switching to it seemed to give a better result without weights. I started playing around with it, and it does seem to help.

@mMagmer do you know what the best way to optimize the weighting would be?

One other thing that is confusing me is using pretrained encoders vs. starting from scratch. Do I still only have 10-20 epochs to train a model before overfitting becomes a risk? My training scores are pretty low by 25 epochs, but my validation scores are worse.


For cross entropy, I think the weight for each class is 1 / (number of pixels of that class).
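A minimal sketch of inverse-frequency class weights computed from the label masks, assuming integer masks of shape (N, H, W). The function name and the eps floor (which keeps a class with no pixels from getting an infinite weight) are hypothetical choices.

```python
import torch


def inverse_frequency_weights(masks: torch.Tensor, num_classes: int, eps: float = 1.0) -> torch.Tensor:
    """Class weights proportional to 1 / (pixel count of each class).

    masks: integer label maps of shape (N, H, W), e.g. the training set or a sample of it
    eps:   floor on the counts so an absent class does not get an infinite weight
    """
    counts = torch.bincount(masks.flatten(), minlength=num_classes).float()
    weights = 1.0 / counts.clamp_min(eps)
    # Rescale so the weights average to 1; only the ratios between classes matter
    return weights * num_classes / weights.sum()
```

The result can be passed directly as the weight argument of nn.CrossEntropyLoss.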
For Dice, maybe you can find something in [1707.03237] Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations.

"...when choosing the GDLv weighting, the contribution of each label is corrected by the inverse of its volume, thus reducing the well known correlation between region size and Dice score."
Beyond that, I really don’t know.
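As I read that paper, its Generalised Dice Loss is 1 - 2 (Σ_l w_l Σ_n r_ln p_ln) / (Σ_l w_l Σ_n (r_ln + p_ln)), with r_ln the one-hot ground truth and p_ln the predicted probability for class l at pixel n, and w_l = 1 / (Σ_n r_ln)^2 (so the weight is actually the inverse of the squared volume). Below is a minimal sketch under that reading; giving classes absent from the batch a weight of 0 is my own assumption to avoid division by zero.

```python
import torch
import torch.nn.functional as F


def generalised_dice_loss(logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Sketch of the Generalised Dice Loss (arXiv:1707.03237).

    logits: (N, C, H, W) raw scores
    target: (N, H, W) integer class labels
    """
    num_classes = logits.size(1)
    probs = logits.softmax(dim=1)                                                  # p_ln
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).type_as(probs)   # r_ln

    dims = (0, 2, 3)
    volume = one_hot.sum(dims)  # sum_n r_ln, one value per class
    # w_l = 1 / (sum_n r_ln)^2; absent classes get weight 0 (assumption, see above)
    weights = torch.where(volume > 0, 1.0 / volume.clamp_min(1.0) ** 2, torch.zeros_like(volume))

    intersection = (weights * (probs * one_hot).sum(dims)).sum()
    denominator = (weights * (probs + one_hot).sum(dims)).sum()
    return 1.0 - 2.0 * intersection / denominator.clamp_min(eps)
```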

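When your training error is still high, you’re not overfitting. Overfitting shows up as a widening gap: training error keeps falling while validation error stalls or rises.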
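Rather than fixing an epoch count in advance, a common approach is to track the validation loss each epoch and stop once it has not improved for a few epochs, keeping the best checkpoint. A minimal sketch of such patience-based early stopping follows; the class name, patience value, and loss numbers are made up for illustration.

```python
class EarlyStopping:
    """Hypothetical helper: stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: this is where you would save a checkpoint
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
val_losses = [1.0, 0.8, 0.7, 0.72, 0.75, 0.71, 0.9]  # made-up numbers
stopped_epoch = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.step(val_loss):
        stopped_epoch = epoch
        break
```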