Unexpected behavior of Dice Loss on extremely imbalanced binary segmentation

Hello everyone,

I am currently working on a binary semantic segmentation task. The model takes grid-based rainfall data from the past three hours as input and predicts, for each grid point, whether rainfall will exceed a given threshold at lead times of 2 to 9 hours into the future (similar to afternoon thunderstorm forecasting).

This is a multi-channel output task in which each pixel is a binary classification problem. The dataset is highly imbalanced, with a very low proportion of positive samples. I tested two datasets with different positive-sample ratios (see the sketch after this list for how the ratio is defined):

Group A: 0.3% positive samples

Group B: 0.9% positive samples
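
By "positive sample ratio" I mean the fraction of pixels labelled 1 over all pixels and channels of the target masks. A minimal sketch with a synthetic tensor, just to show the computation (not my actual data pipeline):

import torch

def positive_ratio(target_masks: torch.Tensor) -> float:
    # fraction of pixels labelled 1 over all pixels and channels
    return target_masks.float().mean().item()

# synthetic (N, C, H, W) mask with roughly 0.3% positives, like Group A
masks = (torch.rand(16, 8, 64, 64) < 0.003).float()
print(positive_ratio(masks))  # ~0.003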

The model I used is a U-Net, the loss function is Dice Loss, the batch size is 4, and the optimizer is Adam. A rough sketch of the training setup is below, followed by my experimental results.
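
This is only an illustrative sketch, not my actual training script: the Conv2d layer stands in for the U-Net, the data are random, the learning rate is just an example, and DiceLoss refers to the class defined further down in this post.

import torch

model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)      # stand-in for the U-Net (3 past hours in, 8 lead times out)
criterion = DiceLoss()                                        # the DiceLoss class shown below
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)     # learning rate shown only as an example

# one dummy batch: 4 samples, 3 input hours, 8 output lead times, 64x64 grid
inputs = torch.randn(4, 3, 64, 64)
targets = (torch.rand(4, 8, 64, 64) < 0.003).float()          # ~0.3% positives, like Group A

optimizer.zero_grad()
logits = model(inputs)                                        # raw logits; sigmoid is applied inside dice_loss
loss = criterion(logits, targets)
loss.backward()
optimizer.step()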

Group A (0.3% positive samples):

Final CSI Score: 0.0018915347754955292

Final Recall Score: 0.4033149182796478

Final F1score Score: 0.0037759272381663322

Final Precision Score: 0.001896842964924872

Group B (0.9% positive samples):

Final CSI Score: 0.002810424193739891

Final Recall Score: 0.016717325896024704

Final F1score Score: 0.005605095531791449

Final Precision Score: 0.003367003286257386
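
For reference, CSI (critical success index, also called threat score) is TP / (TP + FP + FN). A minimal sketch of how such scores can be computed from binarized predictions (the 0.5 threshold and pixel-wise aggregation are assumptions of the sketch, not necessarily exactly how my evaluation code works):

import torch

def binary_scores(probs: torch.Tensor, target: torch.Tensor, thr: float = 0.5) -> dict:
    pred = (probs >= thr).float()
    tp = (pred * target).sum()
    fp = (pred * (1 - target)).sum()
    fn = ((1 - pred) * target).sum()
    eps = 1e-8
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    csi = tp / (tp + fp + fn + eps)   # critical success index (threat score)
    return {"CSI": csi.item(), "Recall": recall.item(), "Precision": precision.item(), "F1": f1.item()}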

Below is my code, a modified version of kornia's DiceLoss that applies a per-channel sigmoid instead of softmax + one-hot:

import torch
from torch import nn

from kornia.core import Tensor
from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR

# based on:
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
# https://github.com/Lightning-AI/metrics/blob/v0.11.3/src/torchmetrics/functional/classification/dice.py#L66-L207


def dice_loss(
    pred: Tensor,
    target: Tensor,
    average: str = "macro",
    eps: float = 1e-8,
) -> Tensor:
    
    KORNIA_CHECK_IS_TENSOR(pred)

    if not len(pred.shape) == 4:
        raise ValueError(f"Invalid pred shape, we expect BxNxHxW. Got: {pred.shape}")

    if not pred.shape[-2:] == target.shape[-2:]:
        raise ValueError(f"pred and target shapes must be the same. Got: {pred.shape} and {target.shape}")

    if not pred.device == target.device:
        raise ValueError(f"pred and target must be on the same device. Got: {pred.device} and {target.device}")

    possible_average = {"micro", "macro"}
    KORNIA_CHECK(average in possible_average, f"The `average` has to be one of {possible_average}. Got: {average}")

    # per-channel sigmoid: each output channel is treated as an independent binary problem,
    # so `target` is expected to be a float binary mask with the same BxNxHxW shape as `pred`
    pred_soft: Tensor = pred.sigmoid()

    # set dimensions for the appropriate averaging:
    # macro -> reduce over H and W only (one score per sample and channel)
    # micro -> additionally reduce over the channel dimension
    dims: tuple[int, ...] = (2, 3)
    if average == "micro":
        dims = (1, *dims)

    intersection = torch.sum(pred_soft * target, dims)
    cardinality = torch.sum(pred_soft + target, dims)

    dice_score = 2.0 * intersection / (cardinality + eps)
    dice_loss = 1.0 - dice_score

    # reduce across channels in case of `macro` averaging (note: `sum` means the final
    # value scales with the number of output channels before the batch mean below)
    if average == "macro":
        dice_loss = dice_loss.sum(-1)

    dice_loss = torch.mean(dice_loss)

    return dice_loss


class DiceLoss(nn.Module):
    r"""Criterion that computes Sørensen-Dice Coefficient loss.

    According to [1], we compute the Sørensen-Dice Coefficient as follows:

    .. math::

        \text{Dice}(x, class) = \frac{2 |X \cap Y|}{|X| + |Y|}

    Where:
       - :math:`X` is the per-channel sigmoid probability map.
       - :math:`Y` is the binary target mask of the same shape.

    The loss is finally computed as:

    .. math::

        \text{loss}(x, class) = 1 - \text{Dice}(x, class)

    Reference:
        [1] https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

    Args:
        average:
            Reduction applied across channels:
            - ``'macro'`` [default]: calculate the loss for each channel separately, then reduce across channels.
            - ``'micro'``: calculate the loss across all channels at once.
        eps: Scalar to enforce numerical stability.

    Shape:
        - Pred: :math:`(N, C, H, W)` where C = number of output channels (raw logits).
        - Target: :math:`(N, C, H, W)` binary mask with values in :math:`\{0, 1\}`.

    Example:
        >>> N = 5  # num output channels
        >>> criterion = DiceLoss()
        >>> pred = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.randint(0, 2, (1, N, 3, 5)).float()
        >>> output = criterion(pred, target)
        >>> output.backward()

    """

    def __init__(
        self,
        average: str = "macro",
        eps: float = 1e-8,
    ) -> None:
        super().__init__()
        self.average = average
        self.eps = eps

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        return dice_loss(pred, target, self.average, self.eps)
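
Note that, unlike the original kornia version, this modified loss expects the target to already be a binary mask with the same (B, C, H, W) shape as the logits, since it applies a per-channel sigmoid instead of softmax + one-hot. A quick sanity check on random data (shapes are only illustrative):

logits = torch.randn(2, 8, 32, 32, requires_grad=True)    # (B, C, H, W) raw network outputs
masks = (torch.rand(2, 8, 32, 32) < 0.01).float()         # ~1% positive pixels

loss = dice_loss(logits, masks, average="macro")
loss.backward()
print(loss.item())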

I observed a very strange phenomenon:

In the dataset with a lower positive sample ratio (0.3%), the model tends to predict more positive samples, resulting in a noticeably higher recall but extremely low precision.

In contrast, with a slightly higher positive sample ratio (0.9%), the model becomes more conservative — recall drops to an almost unusable level, while precision slightly improves.

The model architecture, hyperparameters, and training logic are exactly the same for both datasets; the only difference is the positive-sample ratio, yet the model's prediction tendency is completely reversed.
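
To make the regime concrete: with soft predictions, the Dice score above is 2·Σ(p·y) / (Σp + Σy), so when Σy is only a handful of pixels, even a little probability mass spread over the background dominates the denominator. A purely illustrative toy comparison (not my data):

import torch

def soft_dice(p, y, eps=1e-8):
    inter = (p * y).sum()
    card = p.sum() + y.sum()
    return 2.0 * inter / (card + eps)

y = torch.zeros(1, 1, 256, 256)
y[0, 0, :4, :4] = 1.0                   # 16 positive pixels out of 65536 (~0.02%)

p_leaky = torch.full_like(y, 0.01)      # 1% probability leaked everywhere
p_focused = y * 0.9                     # 0.9 only on the true positives

print(soft_dice(p_leaky, y).item())     # ~0.0005: background mass swamps the denominator
print(soft_dice(p_focused, y).item())   # ~0.95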

I would like to ask if anyone has encountered a similar phenomenon when using Dice Loss. Can this behavior be explained by the mathematical properties of Dice Loss?