Is this 3D Dice loss implementation correct?

def dice_loss(preds, labels, classes):
    """Compute a soft (squared-denominator) multi-class Dice loss.

    Designed for volumetric segmentation, but works for any input with the
    class axis at dim 1 (e.g. 2D ``(N, C, H, W)`` as well).

    Args:
        preds: raw network scores (logits) of shape ``(N, C, D, H, W)``;
            a softmax over the class dim (dim 1) is applied here.
        labels: class indices of shape ``(N, 1, D, H, W)`` or
            ``(N, D, H, W)``, with values in ``[0, classes)``. Float tensors
            holding integral values are accepted and cast to int64.
        classes: number of classes ``C``; must equal ``preds.shape[1]``.

    Returns:
        Scalar tensor ``1 - mean_c(dice_c)`` where, per class ``c``,
        ``dice_c = (2 * sum(p * g) + smooth) / (sum(p**2) + sum(g**2) + smooth)``
        and the sums run over the batch and all spatial positions.
        Lower is better; a perfect prediction gives a loss near 0.

    Raises:
        ValueError: if ``preds.shape[1] != classes`` or the labels cannot be
            aligned with ``preds``.
    """
    if preds.shape[1] != classes:
        raise ValueError(
            f"preds has {preds.shape[1]} channels but classes={classes}"
        )
    smooth = 1e-7

    probs = F.softmax(preds, dim=1)

    # Drop the singleton channel dim of (N, 1, D, H, W) labels.
    if labels.dim() == preds.dim():
        labels = labels.squeeze(1)
    # one_hot needs int64 and appends the class axis LAST:
    # (N, D, H, W) -> (N, D, H, W, C). Move it to dim 1 so it lines up
    # voxel-for-voxel with the channel-first ``probs``. (A plain
    # ``view(-1, C)`` on the channel-first probs would NOT do this.)
    target = F.one_hot(labels.long(), classes).movedim(-1, 1).to(probs.dtype)
    if target.shape != probs.shape:
        raise ValueError(
            f"labels shape {tuple(labels.shape)} is incompatible with "
            f"preds shape {tuple(preds.shape)}"
        )

    # Reduce over batch and every spatial dim, keeping one value per class.
    reduce_dims = (0,) + tuple(range(2, probs.dim()))
    intersection = 2 * torch.sum(probs * target, dim=reduce_dims) + smooth
    denominator = (
        torch.sum(probs * probs, dim=reduce_dims)
        + torch.sum(target * target, dim=reduce_dims)
        + smooth
    )
    # 1 - Dice instead of -Dice: same gradient, but the loss lies in [0, 1].
    return 1 - torch.mean(intersection / denominator)