I modified the Dice loss code from the kornia library. The kornia implementation is designed for semantic segmentation and expects tensors of shape [Batch, Class/Logit, Height, Width]. For scene labeling, the expected shape is [Batch, Class/Logit], so I made a simple modification: I removed dimensions 2 and 3 from the dims argument. For a macro-average, the goal is to calculate the metric separately for each class and then take the average. This is accomplished by not specifying a dimension argument in torch.sum(). To obtain a micro-average, where all classes are pooled to calculate the total counts of TPs and errors (FPs and FNs), the dimension argument is set to 1. To obtain a single metric in the case of macro averaging, the results are averaged using torch.mean().
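For reference, the quantity being computed can be written out as follows. This is the usual soft Dice formulation in my own notation (an assumption, not taken from kornia): $p_{n,c}$ is the softmax probability for sample $n$ and class $c$, $y_{n,c}$ is the one-hot reference, $\epsilon$ is the stabilizing constant, and $S$ is whichever set of $(n, c)$ index pairs a single sum runs over.

```latex
\mathrm{Dice}(S) = \frac{2 \sum_{(n,c) \in S} p_{n,c}\, y_{n,c}}
                        {\sum_{(n,c) \in S} \left( p_{n,c} + y_{n,c} \right) + \epsilon},
\qquad
\mathcal{L}_{\mathrm{Dice}} = 1 - \operatorname{mean}_{S}\, \mathrm{Dice}(S)
```

The choice of $S$ is what separates the averaging modes: one set per class gives per-class scores that are then averaged (macro), while a single set covering all samples and classes pools every count into one score (micro).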
The math and logic are implemented in a function, which is then called within a subclass of nn.Module.
I believe the function is working correctly but would appreciate a double-check of the implementation. Is the metric being calculated correctly, and are macro and micro averaging correctly defined? Any input is appreciated.
import torch
import torch.nn as nn
import torch.nn.functional as F


def dice_loss(pred, ref, nCls, average, eps: float = 1e-8):
    # Input tensors will have shape (Batch, Class)
    # Dimension 0 = batch
    # Dimension 1 = class code or predicted logit

    # compute softmax over the classes axis to convert logits to probabilities
    pred_soft = torch.softmax(pred, dim=1)

    # create reference one hot tensors
    ref_one_hot = F.one_hot(ref, num_classes=nCls)

    # Calculate the dice loss
    if average == "micro":
        # Use dim=1 to aggregate results across all classes
        intersection = torch.sum(pred_soft * ref_one_hot, dim=1)
        cardinality = torch.sum(pred_soft + ref_one_hot, dim=1)
    else:
        # With no dim argument, will be calculated separately for each class
        intersection = torch.sum(pred_soft * ref_one_hot)
        cardinality = torch.sum(pred_soft + ref_one_hot)

    dice_score = 2.0 * intersection / (cardinality + eps)
    dice_loss = -dice_score + 1.0

    # reduce the loss across samples (and classes in case of `macro` averaging)
    dice_loss = torch.mean(dice_loss)

    return dice_loss


class DiceLoss(nn.Module):
    def __init__(self, nCls, average, eps: float = 1e-8) -> None:
        super().__init__()
        self.nCls = nCls
        self.average = average
        self.eps = eps

    def forward(self, pred, ref):
        return dice_loss(pred, ref, self.nCls, self.average, self.eps)
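One way to double-check the two branches is to look at the shapes that torch.sum() returns on a small random batch. The sketch below is only a sanity check, not part of the loss itself: the batch size, class count, and variable names are made up for illustration, and it adds a third reduction over dim=0 purely for comparison.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes chosen only for illustration.
batch_size, n_classes = 4, 3

pred = torch.randn(batch_size, n_classes)         # raw logits, shape (Batch, Class)
ref = torch.randint(0, n_classes, (batch_size,))  # integer class labels, shape (Batch,)

pred_soft = torch.softmax(pred, dim=1)
ref_one_hot = F.one_hot(ref, num_classes=n_classes)

product = pred_soft * ref_one_hot                 # shape (Batch, Class)

per_sample = torch.sum(product, dim=1)  # sums over classes: one value per sample
pooled = torch.sum(product)             # no dim: sums over everything, a single scalar
per_class = torch.sum(product, dim=0)   # sums over the batch: one value per class

print(per_sample.shape)  # torch.Size([4])
print(pooled.shape)      # torch.Size([])
print(per_class.shape)   # torch.Size([3])
```

If these shapes hold, the branch without a dim argument produces a single pooled value rather than one value per class, and dim=1 produces one value per sample. A per-class score would then need the reduction over dimension 0. I have not compared this against kornia's own behavior.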