I modified the dice loss code from the kornia library (link). That library is designed for semantic segmentation tasks and expects tensors of shape [Batch, Class/Logit, Height, Width]. For scene labeling tasks, the expected shape is [Batch, Class/Logit], so I made a simple modification to the code: I removed dimensions 2 and 3 from the dims argument passed to torch.sum(). For a macro-average, the goal is to calculate the metric separately for each class and then take an average. This is accomplished by summing over dimension 0 (the batch dimension), which yields one count per class; torch.mean() then reduces the per-class dice losses to a single metric. To obtain a micro-average, where all classes are pooled to calculate the total counts of TPs and errors (FPs and FNs), torch.sum() is called with no dimension argument, so it sums over both the batch and class dimensions and produces a single pooled score.
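To make the two reductions concrete, here is a minimal sketch of the torch.sum() behavior I am relying on (the tensor values are made up purely for illustration):

import torch

# made-up counts for a batch of 2 samples and 3 classes
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

print(torch.sum(x))         # no dim argument pools everything -> tensor(21.)
print(torch.sum(x, dim=0))  # per-class totals over the batch -> tensor([5., 7., 9.])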
The math and logic are implemented in a function, which is then called within a subclass of nn.Module.
I believe the function is working correctly, but I would appreciate a double-check of the implementation: is the metric being calculated correctly, and are macro and micro averaging correctly defined? Any input is appreciated.
import torch
import torch.nn as nn
import torch.nn.functional as F
def dice_loss(pred, ref, nCls, average, eps: float = 1e-8):
    # Input tensors have shape (Batch, Class)
    # Dimension 0 = batch
    # Dimension 1 = class code or predicted logit
    # compute softmax over the class axis to convert logits to probabilities
    pred_soft = torch.softmax(pred, dim=1)
    # create reference one-hot tensors; cast to float for the arithmetic below
    ref_one_hot = F.one_hot(ref, num_classes=nCls).float()
    # calculate the dice loss
    if average == "micro":
        # with no dim argument, torch.sum() pools the counts over both the
        # batch and class dimensions, yielding a single pooled total
        intersection = torch.sum(pred_soft * ref_one_hot)
        cardinality = torch.sum(pred_soft + ref_one_hot)
    else:
        # summing over dim=0 (the batch dimension) keeps one count per class,
        # so the dice score is calculated separately for each class
        intersection = torch.sum(pred_soft * ref_one_hot, dim=0)
        cardinality = torch.sum(pred_soft + ref_one_hot, dim=0)
    dice_score = 2.0 * intersection / (cardinality + eps)
    loss = 1.0 - dice_score
    # reduce across classes for macro averaging (a no-op for the micro scalar)
    return torch.mean(loss)
class DiceLoss(nn.Module):
    def __init__(self, nCls, average, eps: float = 1e-8) -> None:
        super().__init__()
        self.nCls = nCls
        self.average = average
        self.eps = eps

    def forward(self, pred, ref):
        return dice_loss(pred, ref, self.nCls, self.average, self.eps)
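In case it helps with the double-check, here is a minimal usage sketch; the batch size, class count, seed, and random tensors are made up purely for illustration:

torch.manual_seed(42)
logits = torch.randn(8, 4)          # [Batch, Class] raw logits
labels = torch.randint(0, 4, (8,))  # integer class codes, shape [Batch]

macro_criterion = DiceLoss(nCls=4, average="macro")
micro_criterion = DiceLoss(nCls=4, average="micro")

print(macro_criterion(logits, labels))  # scalar: mean of per-class dice losses
print(micro_criterion(logits, labels))  # scalar: dice loss from pooled counts

Both calls return a single scalar suitable for backpropagation; they differ only in whether the counts are kept per class before averaging or pooled first.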