One-hot encoding for batches with different classes in semantic segmentation

Let’s suppose I have 60 classes and not every image contains all 60 of them.
Say my first batch contains 40 classes, so it gets encoded with only 40 classes, and my second batch contains 50 classes, so it gets encoded with 50. Shouldn’t every batch use a 60-class encoding?

Hi Talha!

Yes, you should encode all 60 classes for all of your batches even
if some batches don’t happen to have all 60 classes in them.

To expand on this a little: Your network architecture will have the
60 classes built into it (at least in standard approaches). Let’s say
that we label your 60 classes with the integers 0, …, 59. Now let’s
say that a particular batch happens to have only three classes in it,
specifically 3, 27, and 45. If you were to label this batch with, say, 0,
1, and 2, so that you gave class-27 the label 1, your network wouldn’t
know what that meant.
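To make the pitfall concrete, here is a minimal sketch, assuming a toy batch of integer label masks whose only classes are 3, 27, and 45. It builds exactly the batch-local relabeling described above with torch.unique(..., return_inverse=True), shown as what *not* to do: class 27 silently becomes label 1, which a network with 60 output channels would read as class 1.

```python
import torch

NUM_CLASSES = 60  # fixed by the network's output layer

# Toy batch of 2 label masks (2x3 pixels each) containing only classes 3, 27, 45
labels = torch.tensor([[[3, 3, 27],
                        [27, 45, 45]],
                       [[45, 3, 3],
                        [27, 27, 3]]])

# The mistake: remap the classes present in this batch to 0, 1, 2, ...
# `present` holds the sorted classes found; `compact` has the same shape as `labels`
present, compact = torch.unique(labels, return_inverse=True)

# Correct approach: keep the global class ids unchanged, since they index
# the network's 60 output channels directly
assert int(labels.max()) < NUM_CLASSES
```

After the remapping, every pixel of class 27 carries the label 1, and a second batch with different classes present would assign label 1 to something else entirely.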

As an aside, whether you use one-hot encoding or something else
to label your classes is irrelevant to this question – you need all 60
either way. Having said that, you shouldn’t be using one-hot encoding.
Semantic segmentation is a classification problem (you classify
individual pixels), so you would most typically use some version
of a cross-entropy-like loss function. PyTorch’s cross-entropy loss
functions are designed around integer categorical class labels (like I
alluded to, above), so you don’t need one-hot encoding for them.
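A minimal sketch of the integer-label setup for per-pixel classification, assuming logits of shape [B, 60, H, W] and random toy data. The target has no class dimension at all; each pixel simply stores its class index. The hand-computed version shows that the loss just picks out the log-probability of the true class at each pixel and averages.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NUM_CLASSES = 60
B, H, W = 2, 4, 4

torch.manual_seed(0)
# Raw network output: one score per class per pixel, shape [B, C, H, W]
logits = torch.randn(B, NUM_CLASSES, H, W)
# Integer class labels in 0..59, shape [B, H, W], dtype long (no one-hot)
target = torch.randint(0, NUM_CLASSES, (B, H, W))

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)

# Same value by hand: log-probability of the true class at each pixel, averaged
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs.gather(1, target.unsqueeze(1)).mean()
```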

(You should also be aware that people sometimes, mistakenly, use
the term “one hot” to refer to class labels of any sort, even if they are
not actually one-hot encoded. This, of course, can be confusing.)

Best.

K. Frank

I am using PyTorch’s one_hot() function for this, applying it to each batch. Do I need to do something extra, or does the function manage that internally?
Another way I thought of: find the images that have all 60 labels, fit sklearn’s OneHotEncoder on them, and then just transform all of the remaining images.
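For reference, torch.nn.functional.one_hot() knows nothing about your dataset. When num_classes is omitted (default -1), it infers the width as the largest label in the tensor plus one, so the encoding size changes from batch to batch. Passing num_classes fixes the width, and no fitting step like sklearn’s is needed. A small sketch with made-up toy batches:

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 60

# Two toy batches of label masks [B, H, W] with different classes present
batch_a = torch.tensor([[[0, 5], [39, 39]]])   # largest label is 39
batch_b = torch.tensor([[[2, 49], [7, 0]]])    # largest label is 49

# Without num_classes, width = largest label + 1, so it differs per batch
inferred_a = F.one_hot(batch_a)   # [1, 2, 2, 40]
inferred_b = F.one_hot(batch_b)   # [1, 2, 2, 50]

# With num_classes, every batch gets all 60 channels
fixed_a = F.one_hot(batch_a, num_classes=NUM_CLASSES)   # [1, 2, 2, 60]
fixed_b = F.one_hot(batch_b, num_classes=NUM_CLASSES)   # [1, 2, 2, 60]

# one_hot appends the class dim last; move it next to the batch dim: [B, C, H, W]
fixed_a = fixed_a.permute(0, 3, 1, 2)
```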

Hi Talha!

Why are you using torch.nn.functional.one_hot()? I noted in my
previous post that you probably don’t want to use one_hot(). What
is your specific goal?

Why do you want to find those images that have all 60 labels?
(It’s at least a possibility that no single image has all of the labels.)
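If you want to check this for your data, a short sketch can count the distinct labels in each mask. The helper name classes_per_image is hypothetical, and the two toy masks stand in for a real dataset.

```python
import torch

NUM_CLASSES = 60

def classes_per_image(labels: torch.Tensor) -> list:
    """Count the distinct class labels in each [H, W] mask of an [N, H, W] stack."""
    return [int(torch.unique(mask).numel()) for mask in labels]

# Toy masks: the first contains every class 0..59, the second only class 3
labels = torch.stack([
    torch.arange(64).remainder(NUM_CLASSES).reshape(8, 8),
    torch.full((8, 8), 3, dtype=torch.long),
])

counts = classes_per_image(labels)
has_all = [c == NUM_CLASSES for c in counts]
```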

What is your actual use case? What is your larger goal here?

Best.

K. Frank

I am trying to calculate the Dice coefficient. For that I need to average over the classes, and for that I need a one-hot encoded vector. Here is the code:

import torch
import torch.nn.functional as F
from einops import rearrange

def dice_coef_multilabel(mask_pred, mask_gt):
    def compute_dice_coefficient(mask_pred, mask_gt, smooth=0.0001):
        """Compute the (soft) Sørensen-Dice coefficient for a single class.

        Computes the Dice coefficient between the ground-truth mask `mask_gt`
        and the predicted mask `mask_pred`.

        Args:
          mask_pred: float tensor of predicted probabilities for one class, [B, H, W]
          mask_gt: one-hot (0/1) tensor for the same class, [B, H, W]
        Returns:
          the Dice coefficient as a scalar tensor. Because of `smooth`, it is
          1.0 (not NaN) when both masks are empty.
        """
        volume_sum = mask_gt.sum() + mask_pred.sum()
        volume_intersect = (mask_gt * mask_pred).sum()
        return (2 * volume_intersect + smooth) / (volume_sum + smooth)

    dice = 0
    n_pred_ch = mask_pred.shape[1]            # C: number of classes, e.g. 60
    mask_pred = torch.softmax(mask_pred, 1)   # logits [B, C, H, W] -> probabilities
    # mask_gt holds integer labels, [B, 1, H, W]. num_classes makes one_hot produce
    # all C channels even if some classes are absent from this batch.
    mask_gt = F.one_hot(mask_gt.long(), num_classes=n_pred_ch)          # [B, 1, H, W, C]
    mask_gt = rearrange(mask_gt, 'd0 d1 d2 d3 d4 -> d0 (d1 d4) d2 d3')  # [B, C, H, W]

    # The loop skips class 0 (assumed to be background), so the average
    # must be over the remaining n_pred_ch - 1 classes, not n_pred_ch.
    for ind in range(1, n_pred_ch):
        dice += compute_dice_coefficient(mask_pred[:, ind, :, :], mask_gt[:, ind, :, :])
    return dice / (n_pred_ch - 1)  # mean Dice over the foreground classes
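The per-class loop can also be written with tensor reductions. The sketch below (the name per_class_dice is hypothetical) returns one soft Dice value per class, pooled over the batch, which is handy for inspecting classes individually. It assumes integer labels of shape [B, H, W] and, like the code above, that class 0 is background and left out of the mean. It uses permute in place of einops.rearrange.

```python
import torch
import torch.nn.functional as F

def per_class_dice(logits, labels, smooth=1e-4):
    """Soft Dice per class.

    logits: [B, C, H, W] raw network output
    labels: [B, H, W] integer class labels in 0..C-1
    returns: [C] tensor, one Dice coefficient per class, pooled over the batch
    """
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    # num_classes guarantees C channels even if some classes are absent
    one_hot = F.one_hot(labels.long(), num_classes=num_classes)  # [B, H, W, C]
    one_hot = one_hot.permute(0, 3, 1, 2).to(probs.dtype)          # [B, C, H, W]
    dims = (0, 2, 3)  # sum over batch and spatial dims, keep the class dim
    intersect = (probs * one_hot).sum(dims)
    volume_sum = probs.sum(dims) + one_hot.sum(dims)
    return (2 * intersect + smooth) / (volume_sum + smooth)

# Toy check: confident, correct predictions should give Dice close to 1
labels = torch.tensor([[[0, 1], [2, 2]]])                        # [1, 2, 2], classes 0..2
logits = 20.0 * F.one_hot(labels, 3).permute(0, 3, 1, 2).float()  # [1, 3, 2, 2]
dice = per_class_dice(logits, labels)
mean_fg_dice = dice[1:].mean()                                   # exclude background class 0
```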

Hi Talha!

Okay, for Dice loss, one_hot() makes sense.

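This line from the code you’ve posted,
mask_gt=F.one_hot(mask_gt.long(), num_classes=n_pred_ch),
seems to indicate that you’ve answered your own question: passing
num_classes makes one_hot() encode all n_pred_ch classes, whether or
not they all appear in a given batch.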

To repeat my question from my previous post: Why do you want to
find those images that have all 60 labels?

As an aside, some people find that for semantic segmentation using
Dice loss to augment CrossEntropyLoss works better than using
Dice loss in isolation. You might want to explore that after you get
your Dice loss working.

Best.

K. Frank