Multi-class dice loss function

Hello everyone, I am trying to use dice loss for my 3D point cloud semantic segmentation model.
Although I have implemented the function by adapting some existing code, I am not sure whether it is correct, because my validation IoU does not improve compared with using cross-entropy loss alone.
Below is my function for multi-class dice loss:

import torch
from torch.nn.functional import one_hot

# batch_size and device are globals defined elsewhere in the training script
def diceLoss(prediction_g, label_g, num_class, epsilon=1):
    diceRatio_g = 0
    label_g = one_hot(label_g, 15).to(device)  # one-hot encode 15 classes
    label_g = label_g.reshape(batch_size*16384, -1)  # 16384 points per cloud -> (batch_size*n_pts, 15)
    prediction_g = prediction_g.reshape(batch_size*16384, -1)  # reshape to (batch_size*n_pts, 15)

    for i in range(num_class):
        pred = prediction_g
        label = label_g
        pred = torch.nn.functional.softmax(pred, dim=1)[:, i]  # select the ith class probability
        pred = pred.reshape(-1, 1)  # (batch_size*n_pts, 1)

        diceLabel_g = label.sum(dim=0)
        diceLabel_g = diceLabel_g[i]
        dicePrediction_g = pred.sum(dim=0)
        diceCorrect_g = (pred * label)[:, 0]
        diceCorrect_g = diceCorrect_g.sum(dim=0)
        diceRatio_g += (2 * diceCorrect_g + epsilon) \
            / (dicePrediction_g + diceLabel_g + epsilon)
    loss = 1 - (1/num_class)*diceRatio_g
    return loss
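For reference, the soft multi-class dice loss that this function is meant to compute can be written as follows, where p_{n,c} is the softmax probability of point n for class c, y_{n,c} is the one-hot label, N = batch_size × 16384 is the number of points, C is the number of classes and ε is the smoothing term epsilon (notation chosen here for illustration):

$$
\mathcal{L}_{\text{dice}} = 1 - \frac{1}{C}\sum_{c=1}^{C} \frac{2\sum_{n=1}^{N} p_{n,c}\,y_{n,c} + \epsilon}{\sum_{n=1}^{N} p_{n,c} + \sum_{n=1}^{N} y_{n,c} + \epsilon}
$$

The numerator of each term should only pair class c's probability with class c's label.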

Please have a look and let me know if there is any problem. Thank you.

Hi @edshkim98,

I think the problem is diceCorrect_g = (pred * label)[:,0]. pred has shape [B*N, 1] and holds the model's output for the selected class i, while label still has shape [B*N, C]. Multiplying them broadcasts to shape [B*N, C], and [:,0] then selects column 0, i.e. the class-i probability multiplied by the one-hot indicator for class 0. So in every iteration the intersection term only counts points labelled as class 0 and is zero for all other points.
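To make the shape issue concrete, here is a small toy reproduction of one loop iteration. The sizes, the random logits and the label values are made up for illustration; only the broadcasting behaviour matters.

```python
import torch
import torch.nn.functional as F

# Toy sizes (hypothetical): 6 points, 3 classes
num_points, num_classes = 6, 3
torch.manual_seed(0)
logits = torch.randn(num_points, num_classes)
labels = torch.tensor([0, 1, 2, 1, 2, 2])
label_oh = F.one_hot(labels, num_classes).float()  # (N, C)

i = 2  # class being scored in this loop iteration
pred = F.softmax(logits, dim=1)[:, i].reshape(-1, 1)  # (N, 1)

product = pred * label_oh  # broadcasts (N, 1) * (N, C) -> (N, C)
print(product.shape)  # torch.Size([6, 3])

buggy = product[:, 0].sum()  # what the original line sums: p_i * [label == 0]
fixed = product[:, i].sum()  # intended intersection: p_i * [label == i]
```

Here only point 0 has label 0, so buggy is just that point's class-2 probability, whereas fixed sums the class-2 probabilities of the three points actually labelled 2.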
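Either select class i from the product as well, i.e. diceCorrect_g = (pred * label)[:, i], so each point contributes p_i * y_i, or remove the for loop entirely and compute all classes at once: with probs = softmax(prediction_g, dim=-1), use diceCorrect_g = (label_g * probs).sum(dim=0), dicePrediction_g = probs.sum(dim=0) and diceLabel_g = label_g.sum(dim=0). Each of these is a vector of length C, so the per-class ratios can be computed elementwise and then averaged with .mean().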
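A minimal sketch of the loop-free version, assuming logits of shape (B, N, C) and integer labels of shape (B, N) on the same device; the function name multiclass_dice_loss is hypothetical. It keeps one dice term per class and averages them, which is what the original loop intends once the indexing is fixed.

```python
import torch
import torch.nn.functional as F


def multiclass_dice_loss(logits: torch.Tensor, labels: torch.Tensor,
                         num_classes: int, epsilon: float = 1.0) -> torch.Tensor:
    """Soft dice loss averaged over classes (hypothetical helper).

    logits: (..., C) raw scores, e.g. (B, N, C) for B clouds of N points.
    labels: (...) integer class ids with the same leading shape as logits.
    """
    probs = F.softmax(logits, dim=-1).reshape(-1, num_classes)           # (B*N, C)
    target = F.one_hot(labels.reshape(-1), num_classes).to(probs.dtype)  # (B*N, C)

    intersection = (probs * target).sum(dim=0)  # (C,) per-class soft overlap
    pred_sum = probs.sum(dim=0)                 # (C,) soft predicted count per class
    label_sum = target.sum(dim=0)               # (C,) true count per class
    dice = (2 * intersection + epsilon) / (pred_sum + label_sum + epsilon)
    return 1 - dice.mean()                      # average over classes
```

Keeping .sum(dim=0) matters here: a bare .sum() over both points and classes pools everything into one ratio, and because each row of probs and of the one-hot labels sums to 1, that ratio reduces to roughly the mean probability assigned to the true class, which is dominated by the frequent classes.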
