About Dice loss, Generalized Dice loss

Hello All,
I am running multi-label segmentation of 3D data (batch x classes x H x W x D). The target is 1-hot encoded (all 0s and 1s).
I have some broad questions about which loss to use.
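In the V-Net paper (https://arxiv.org/pdf/1606.04797.pdf), the Dice loss is defined as:
$$D = \frac{2 \sum_{i}^{N} p_i g_i}{\sum_{i}^{N} p_i^2 + \sum_{i}^{N} g_i^2}$$
where p_i are the predicted probabilities and g_i the ground-truth labels over the N voxels. It ranges from 0 to 1 and is meant to be maximized, like the regular IoU/Dice.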
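In the Generalized Dice paper (https://arxiv.org/pdf/1707.03237.pdf), the GDL loss is:
$$\mathrm{GDL} = 1 - 2\,\frac{\sum_{l=1}^{2} w_l \sum_{n} r_{ln} p_{ln}}{\sum_{l=1}^{2} w_l \sum_{n} (r_{ln} + p_{ln})}$$
where r_ln is the reference label and p_ln the prediction for label l at voxel n, and the authors say about the weight: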

when choosing the GDLv weighting, the contribution of each label is
corrected by the inverse of its volume, thus reducing the well-known correlation
between region size and Dice score

which I understand very well.
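
For reference, the GDLv weight quoted above is the inverse of the squared label volume, which is also what `w_l` computes in the GDL code further down. The voxel counts in the worked example are made up purely for illustration:

$$w_l = \frac{1}{\left(\sum_{n} r_{ln}\right)^2}$$

$$\text{label A: } \sum_n r_{An} = 1000 \Rightarrow w_A = 10^{-6}, \qquad \text{label B: } \sum_n r_{Bn} = 10 \Rightarrow w_B = 10^{-2}$$

So the small label B is weighted 10^4 times more heavily than the large label A, which is how the correlation between region size and Dice score is reduced.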

So, when I implement both losses with the following code from: pytorch/torch/nn/functional.py at rogertrullo-dice_loss · rogertrullo/pytorch · GitHub

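import numpy as np
import torch
import torch.nn.functional as F

torch.manual_seed(1001)
out = torch.randn(3, 9, 64, 64, 64)  # raw network outputs; wrapping in Variable is no longer needed
# print >> tensor(5.2134) tensor(-5.4812)
seg = torch.randint(0, 2, [3, 9, 64, 64, 64]).float()  # target is in 1-hot-encoded format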

def dice_loss(prediction, target, epsilon=1e-6):
    """
    prediction is a torch tensor of size Batch x nclasses x H x W x D holding the raw (pre-softmax) scores for each class
    target is a 1-hot representation of the ground truth, should have the same size as the prediction
    """
    assert prediction.size() == target.size(), "prediction and target sizes must be equal."
    assert prediction.dim() == 5, "prediction must be a 5D Tensor."
    uniques = np.unique(target.numpy())
    assert set(list(uniques)) <= set([0, 1]), "target must only contain zeros and ones"

    probs = F.softmax(prediction, dim=1) # channel/classwise
   
    num = probs * target  # b,c,h,w,d -- p*g
    num = torch.sum(num, dim=(2, 3, 4))  # b,c

    den1 = probs * probs  # -- p^2
    den1 = torch.sum(den1, dim=(2, 3, 4))  # b,c

    den2 = target * target  # -- g^2
    den2 = torch.sum(den2, dim=(2, 3, 4))  # b,c

    den = den1 + den2 + epsilon
    dice = 2 * (num / den)  # b,c -- one Dice value per sample and class
    dice_eso = dice  # use dice[:, 1:] here to ignore the background channel
    dice_total = torch.sum(dice_eso) / dice_eso.size(0)  # sum over batch and classes, divided by batch size only

    return dice_total

dice_loss(prediction=out, target=seg)

I get a Dice score of 1.9096. Why is it above 1, and when it is used as the loss criterion during training, up to what limit should it be maximized?
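
A likely explanation, sketched under the assumption that the per-class Dice values are meant to be averaged: the function above sums the Dice over all 9 classes and divides by the batch size only, so its upper bound is the number of classes (9 here), not 1. Averaging over both batch and classes keeps the score in [0, 1]. The name `mean_dice_score` is hypothetical, and the small tensor shape is chosen only to keep the example fast.

```python
import torch
import torch.nn.functional as F


def mean_dice_score(logits, target, epsilon=1e-6):
    """V-Net style Dice averaged over batch and classes (hypothetical helper).

    logits: raw network outputs, shape (B, C, H, W, D)
    target: tensor of zeros and ones with the same shape
    """
    probs = F.softmax(logits, dim=1)
    target = target.to(probs.dtype)
    spatial = (2, 3, 4)
    num = (probs * target).sum(dim=spatial)  # (B, C)
    den = (probs * probs).sum(dim=spatial) + (target * target).sum(dim=spatial) + epsilon
    dice = 2 * num / den  # each entry lies in [0, 1] because 2*p*g <= p^2 + g^2
    return dice.mean()  # mean over batch and classes


torch.manual_seed(0)
logits = torch.randn(2, 9, 8, 8, 8)
target = torch.randint(0, 2, (2, 9, 8, 8, 8))
score = mean_dice_score(logits, target)
summed = score * logits.size(1)  # equals the sum over classes averaged over the batch
print(float(score), float(summed))
```

If this score is used for training, the quantity to minimize would be `1 - score`, which then also stays in [0, 1].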

And when using GDL from: pytorch-3dunet/pytorch3dunet/unet3d/losses.py at eafaa5f830eebfb6dbc4e570d1a4c6b6e25f2a1e · wolny/pytorch-3dunet · GitHub

class GeneralizedDiceLoss(_AbstractDiceLoss):
    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
    """

    def __init__(self, normalization='sigmoid', epsilon=1e-6):
        super().__init__(weight=None, normalization=normalization)
        self.epsilon = epsilon

    def dice(self, prediction, target, weight):
        assert prediction.size() == target.size(), "'prediction' and 'target' must have the same shape"
        prediction = flatten(prediction) #flatten all dimensions except channel/class
        target = flatten(target)
        target = target.float()

        if prediction.size(0) == 1:
            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
            # put foreground and background voxels in separate channels
            prediction = torch.cat((prediction, 1 - prediction), dim=0)
            target = torch.cat((target, 1 - target), dim=0)
        w_l = target.sum(-1)
        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
        w_l.requires_grad = False

        intersect = (prediction * target).sum(-1)
        print(intersect.shape)
        intersect = intersect * w_l

        denominator = (prediction + target).sum(-1)
        print(denominator)
        denominator = (denominator * w_l).clamp(min=self.epsilon)

        return 1 - (2 * (intersect.sum() / denominator.sum()))

GeneralizedDiceLoss(normalization='softmax').dice(prediction=out, target=seg, weight=None)

I get a value of 1.0013 with the same inputs.

I don't know how GDL behaves compared to the V-Net Dice; can the two be compared to some extent in terms of optimization?
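
For a side-by-side comparison, here is a minimal, self-contained sketch of the GDL formula with softmax applied explicitly. It is not the pytorch-3dunet implementation: `generalized_dice_loss` is a hypothetical function, and it assumes raw network outputs as input. One thing it makes visible is that the `dice` method above contains no softmax or sigmoid call of its own, so calling it directly on `out` feeds unnormalized values (including negative ones) into the formula, which may be why the result lands slightly above 1.

```python
import torch
import torch.nn.functional as F


def generalized_dice_loss(logits, target, epsilon=1e-6):
    """Generalized Dice Loss with inverse squared-volume weights (hypothetical helper).

    logits: raw network outputs, shape (B, C, H, W, D)
    target: tensor of zeros and ones with the same shape
    """
    probs = F.softmax(logits, dim=1)
    n_classes = probs.size(1)
    # Move the class axis first and flatten everything else: (C, B*H*W*D)
    probs = probs.transpose(0, 1).reshape(n_classes, -1)
    target = target.transpose(0, 1).reshape(n_classes, -1).to(probs.dtype)

    volume = target.sum(-1)  # number of voxels per class
    w = 1.0 / (volume * volume).clamp(min=epsilon)  # inverse squared volume

    intersect = ((probs * target).sum(-1) * w).sum()
    denominator = ((probs + target).sum(-1) * w).sum().clamp(min=epsilon)
    return 1.0 - 2.0 * intersect / denominator


torch.manual_seed(0)
logits = torch.randn(2, 9, 8, 8, 8)
target = torch.randint(0, 2, (2, 9, 8, 8, 8))
loss = generalized_dice_loss(logits, target)
print(float(loss))
```

With probabilities and a 0/1 target, `2*p*g <= p + g` holds voxel by voxel, so this loss stays in [0, 1] and is minimized, whereas the V-Net Dice score is maximized (or `1 - Dice` minimized).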

Also, how do I use this in training?

For example,

            img, seg = img.to(device), seg.to(device)
            out = model(img)
            optimizer.zero_grad()
            loss = dice_loss_3d(out, seg).cuda()
            # print(loss)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

Should this work?
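
A minimal sketch of such a loop, under a few assumptions: the Dice function returns a score to be maximized, so the quantity to minimize is `1 - dice`; the loss is already on the same device as `out`, so calling `.cuda()` on it is not needed; and `soft_dice` plus the single `Conv3d` layer standing in for the model are hypothetical placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_dice(logits, target, epsilon=1e-6):
    # Hypothetical stand-in: V-Net style Dice averaged over batch and classes.
    probs = F.softmax(logits, dim=1)
    dims = (2, 3, 4)
    num = (probs * target).sum(dim=dims)
    den = (probs * probs).sum(dim=dims) + (target * target).sum(dim=dims) + epsilon
    return (2 * num / den).mean()


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_classes = 3
model = nn.Conv3d(1, n_classes, kernel_size=3, padding=1).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One fake batch: 2 volumes of 8x8x8 voxels with a one-hot target.
img = torch.randn(2, 1, 8, 8, 8)
labels = torch.randint(0, n_classes, (2, 8, 8, 8))
seg = F.one_hot(labels, n_classes).permute(0, 4, 1, 2, 3).float()

losses = []
for step in range(20):
    img_d, seg_d = img.to(device), seg.to(device)
    optimizer.zero_grad()
    out = model(img_d)
    loss = 1 - soft_dice(out, seg_d)  # minimize 1 - Dice
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```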

The generalized Dice loss is implemented in the MONAI framework. Take a look here: monai.losses.dice — MONAI 1.2.0 Documentation