Multiclass dice loss for scene labeling problems

I modified the dice loss code from the kornia library (link). That implementation is designed for semantic segmentation and expects tensors of shape [Batch, Class/Logit, Height, Width], whereas for scene labeling the expected shape is [Batch, Class/Logit]. So I made a simple modification to the code: I removed the dimensions with IDs 2 and 3 from the dims argument.

For a macro-average, the goal is to calculate the metric separately for each class and then average the results. This is accomplished by not specifying a dimension argument in torch.sum(). To obtain a micro-average, where all classes are pooled to calculate the total counts of TPs and errors (FPs and FNs), the dimension argument is set to 1. To obtain a single metric in the case of macro averaging, the results are averaged using torch.mean().
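
Since the averaging scheme depends on which axes torch.sum() reduces, a quick check on a toy (Batch, Class) tensor may help when reviewing this. The 4 x 3 tensor here is an arbitrary example, not data from the original problem.

```python
import torch

# Toy (Batch, Class) tensor: 4 samples, 3 classes (arbitrary values)
x = torch.arange(12, dtype=torch.float32).reshape(4, 3)

total = torch.sum(x)               # no dim: reduces over every element
per_class = torch.sum(x, dim=0)    # reduces over the batch: one value per class
per_sample = torch.sum(x, dim=1)   # reduces over classes: one value per sample

print(total.shape, per_class.shape, per_sample.shape)
# torch.Size([]) torch.Size([3]) torch.Size([4])
```

With no dim, torch.sum() collapses both axes into a single 0-dimensional value; dim=0 leaves one value per class, and dim=1 leaves one value per sample.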

The math and logic are implemented in a function, which is then called within a subclass of nn.Module.

I believe the function works correctly, but I would appreciate a double-check of the implementation. Is the metric calculated correctly, and are macro and micro averaging defined correctly? Any input is appreciated.
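
For reference when checking the definitions, the standard hard-label forms can be written in plain Python from per-class counts: the Dice score of class c is 2*TP_c / (2*TP_c + FP_c + FN_c), the macro-average is the unweighted mean of those per-class scores, and the micro-average pools the counts over all classes before applying the same formula once. The labels below are made up for illustration, and the helper names are hypothetical.

```python
def confusion_counts(pred, ref, n_cls):
    """Per-class TP, FP and FN counts for single-label predictions."""
    tp, fp, fn = [0] * n_cls, [0] * n_cls, [0] * n_cls
    for p, r in zip(pred, ref):
        if p == r:
            tp[p] += 1
        else:
            fp[p] += 1  # predicted class p, but the prediction was wrong
            fn[r] += 1  # true class r was missed
    return tp, fp, fn


def dice(tp, fp, fn):
    # Assumes 2*tp + fp + fn > 0, i.e. the class occurs in pred or ref
    return 2 * tp / (2 * tp + fp + fn)


pred = [0, 0, 1, 2, 2, 2]  # hypothetical predicted labels
ref = [0, 1, 1, 2, 2, 0]   # hypothetical reference labels
tp, fp, fn = confusion_counts(pred, ref, n_cls=3)

# Macro: one score per class, then an unweighted mean
per_class = [dice(t, f, n) for t, f, n in zip(tp, fp, fn)]
macro = sum(per_class) / len(per_class)

# Micro: pool the counts over all classes first, then compute one score
micro = dice(sum(tp), sum(fp), sum(fn))

print(per_class)  # [0.5, 0.6666666666666666, 0.8]
print(macro, micro)
```

In single-label scene labeling, each misclassified sample adds exactly one FP (to the predicted class) and one FN (to the true class), so the hard micro-averaged Dice score reduces to accuracy; here both are 4/6.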

import torch
import torch.nn as nn
import torch.nn.functional as F

def dice_loss(pred, ref, nCls, average, eps: float = 1e-8):
    # pred: float logits with shape (Batch, Class)
    #   Dimension 0 = batch
    #   Dimension 1 = predicted logit per class
    # ref: int64 class indices with shape (Batch,)

    # Compute softmax over the class axis to convert logits to probabilities
    pred_soft = torch.softmax(pred, dim=1)

    # Create one-hot reference tensors with the same dtype as the predictions
    ref_one_hot = F.one_hot(ref, num_classes=nCls).to(pred_soft.dtype)

    # Calculate the dice loss
    if average == "micro":
        # Use dim=1 to aggregate results across all classes
        intersection = torch.sum(pred_soft * ref_one_hot, dim=1)
        cardinality = torch.sum(pred_soft + ref_one_hot, dim=1)
    elif average == "macro":
        # With no dim argument, will be calculated separately for each class
        intersection = torch.sum(pred_soft * ref_one_hot)
        cardinality = torch.sum(pred_soft + ref_one_hot)
    else:
        raise ValueError(f"average must be 'micro' or 'macro', got {average!r}")

    dice_score = 2.0 * intersection / (cardinality + eps)
    loss = 1.0 - dice_score

    # Reduce the loss across samples (and classes in case of `macro` averaging)
    loss = torch.mean(loss)

    return loss

class DiceLoss(nn.Module):
    def __init__(self, nCls, average, eps: float = 1e-8) -> None:
        super().__init__()  # initialize nn.Module internals (hooks, registries)
        self.nCls = nCls
        self.average = average
        self.eps = eps

    def forward(self, pred, ref):
        return dice_loss(pred, ref, self.nCls, self.average, self.eps)
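
As a sanity check that does not depend on the listing above, here is a self-contained sketch that passes the reduction axes explicitly as a dims tuple. The function name soft_dice_loss is hypothetical, and accumulating the counts over the batch axis (dims=(0,) for macro, dims=(0, 1) for micro) is an assumption made for this illustration, by analogy with kornia accumulating over the spatial axes; it is not presented as the definitive choice. Two quick checks follow: confident, correct logits should give a loss near 0, and uniform logits should give a finite, differentiable loss.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits, target, n_cls, average="macro", eps=1e-8):
    """Soft Dice loss for (Batch, Class) logits and (Batch,) integer targets.

    Hypothetical helper. Assumption for this sketch: counts are accumulated
    over the batch axis, per class for "macro" and jointly over batch and
    classes for "micro".
    """
    if average not in ("macro", "micro"):
        raise ValueError(f"unknown average: {average!r}")
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes=n_cls).to(probs.dtype)
    dims = (0,) if average == "macro" else (0, 1)
    intersection = torch.sum(probs * one_hot, dim=dims)  # shape (Class,) or ()
    cardinality = torch.sum(probs + one_hot, dim=dims)
    dice = 2.0 * intersection / (cardinality + eps)
    # Mean over classes for "macro"; a no-op on the scalar for "micro"
    return torch.mean(1.0 - dice)


target = torch.tensor([0, 1, 2, 1])
# Large logit on the true class -> probabilities close to one-hot
confident = 20.0 * F.one_hot(target, num_classes=3).float()
# All-zero logits -> uniform probabilities of 1/3; track gradients
uniform = torch.zeros(4, 3, requires_grad=True)

for avg in ("macro", "micro"):
    print(avg,
          soft_dice_loss(confident, target, 3, avg).item(),
          soft_dice_loss(uniform, target, 3, avg).item())

loss = soft_dice_loss(uniform, target, 3, "macro")
loss.backward()
print(uniform.grad.shape)  # torch.Size([4, 3])
```

With uniform logits every class probability is 1/3, so the micro loss here is 1 - 1/3 = 2/3, while the macro loss depends on how often each class appears in the batch (71/105 for these four labels).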