Logic check for multi-class Dice Loss

I have a multi-class segmentation problem, and I want to use Dice loss to address class imbalance. I found a simple implementation of Dice loss:

def criterionDice(prediction, groundTruth):
    # Soft Dice: 2*|P∩G| / (|P|+|G|), with a small epsilon to avoid division by zero
    diceScore = (2 * (prediction * groundTruth).sum()) / ((prediction + groundTruth).sum() + 1e-8)
    diceLoss = 1 - diceScore
    return diceLoss
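
For context, this is roughly how I call the binary version (a toy sketch with random tensors, not my actual data pipeline):

import torch

# Dummy binary example: probabilities and a matching binary mask,
# both shaped (batch, 256, 256)
logits = torch.randn(4, 256, 256, requires_grad=True)
prediction = torch.sigmoid(logits)                        # probabilities in [0, 1]
groundTruth = torch.randint(0, 2, (4, 256, 256)).float()  # binary mask

loss = criterionDice(prediction, groundTruth)
loss.backward()  # gradients flow fine in this case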

This works well for a simple binary segmentation problem, but when I move to a multi-class problem I run into issues because my prediction and ground truth no longer have the same dimensions. Rather than one-hot encoding the labels, I left them as single-channel class-index maps, so my prediction and ground truth now have the following shapes:

Prediction shape: (Batch Size, Number of Classes, 256, 256)
Ground truth shape: (Batch Size, 256, 256)
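
For reference, by one-hot encoding I mean expanding the ground truth so it matches the prediction's shape, roughly like this sketch (numClasses is a placeholder for my class count):

import torch

# groundTruth: (batch, 256, 256) with integer class indices
oneHot = torch.nn.functional.one_hot(groundTruth.long(), num_classes=numClasses)
oneHot = oneHot.permute(0, 3, 1, 2).float()  # -> (batch, numClasses, 256, 256)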

To ensure my prediction and ground truth have the same shape when calculating the loss, I did the following:

import torch

def criterionDice(prediction, groundTruth):
    # Convert raw logits to per-class probabilities
    probs = torch.nn.functional.softmax(prediction, dim=1)
    # Collapse the class dimension so the shape matches the ground truth
    probs = torch.argmax(probs, dim=1)
    diceScore = (2 * (probs * groundTruth).sum()) / ((probs + groundTruth).sum() + 1e-8)
    diceLoss = 1 - diceScore
    return diceLoss
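
For context, this is roughly how I call it during training (again a toy sketch with random tensors; numClasses is a placeholder):

import torch

numClasses = 5  # placeholder class count
prediction = torch.randn(4, numClasses, 256, 256, requires_grad=True)  # raw logits
groundTruth = torch.randint(0, numClasses, (4, 256, 256))              # class-index map

loss = criterionDice(prediction, groundTruth)
print(loss)  # I get a scalar out, but I am not sure the gradients behave as intended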

But I am really not sure whether this is the right way to solve the issue. Can someone please check whether this is a logically correct implementation of Dice loss for the multi-class case? Any help would be really appreciated.