Mask Shape for Multi Label Image Segmentation

Hello, I was wondering what the expected shape of a mask is for a multi-label image segmentation task. My guess is that each mask’s shape is (channels = num_classes, height, width), since a segmentation model outputs a tensor with as many channels as classes, but I am not entirely sure. Here is an example of a model for a multi-label segmentation task with 5 possible labels:

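import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="efficientnet_b0",  # model backbone
    encoder_weights="imagenet",      # pretrained encoder weights
    in_channels=3,                   # 1 for grayscale, 3 for RGB
    classes=5,                       # one output channel per label
    activation=None,                 # return raw logits
)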

Also, if this is true, I guess there is an implicit order between channels and classes: for example, channel 1 is always for class 1, channel 2 is always for class 2, and so on.

@ptrblck can you confirm this?


Hi Moth!

This is basically right, with the proviso that the mask tensor you pass to your
loss function should have shape [nBatch, num_classes, height, width].
Your batch size (nBatch) can be one, but the nBatch dimension is still
required.
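A minimal sketch of what that looks like in PyTorch, assuming 5 classes and 64x64 images (the sizes and the square regions are made up for illustration): build one sample's [num_classes, height, width] mask, then add the batch dimension with unsqueeze(0).

```python
import torch

num_classes, height, width = 5, 64, 64

# One sample's mask: one channel per class
mask = torch.zeros(num_classes, height, width)
mask[0, 10:20, 10:20] = 1.0  # class 0 present in a small square
mask[2, 15:30, 15:30] = 1.0  # class 2 overlaps class 0 (multi-label)

# Add the batch dimension, even for a single sample
batched_mask = mask.unsqueeze(0)
print(batched_mask.shape)  # torch.Size([1, 5, 64, 64])
```

Pixels inside the overlap are 1.0 in both channel 0 and channel 2, which is exactly what distinguishes the multi-label case.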

Each element of your mask tensor will be a floating-point number that is
the probability of that pixel being in the given class. These values can be
exactly 0.0 or 1.0 if you want “hard,” non-probabilistic mask values.
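As a sketch (random logits stand in for a real model's output), both kinds of mask can be fed to BCEWithLogitsLoss. The main practical point is that the target should be a floating-point tensor, so a mask loaded as booleans or integers is cast with .float() first.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
loss_fn = nn.BCEWithLogitsLoss()
logits = torch.randn(1, 5, 8, 8)  # stand-in for model output

# "Hard" mask, e.g. loaded as booleans: cast to float before the loss
hard_mask = torch.rand(1, 5, 8, 8) > 0.5
hard_loss = loss_fn(logits, hard_mask.float())

# "Soft" mask: any probability in [0, 1] is a valid target value
soft_mask = torch.full((1, 5, 8, 8), 0.3)
soft_loss = loss_fn(logits, soft_mask)
```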

Your predictions, the output of your model, will have the same shape
and consist of logits that correspond to the probability that the given pixel
is in the given class. These would typically be produced by a final Conv2d
layer with out_channels = num_classes (and probably a kernel_size
of 1) and no subsequent sigmoid(). Your loss function should be
BCEWithLogitsLoss.
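Here is a minimal sketch of that setup. The 16 feature channels and the random features are hypothetical stand-ins for whatever your backbone/decoder produces; the point is the 1x1 Conv2d head with out_channels = num_classes, raw logits passed straight to BCEWithLogitsLoss, and sigmoid applied only at inference time.

```python
import torch
import torch.nn as nn

num_classes = 5
feature_channels = 16  # hypothetical: channels produced by the decoder

# Final layer: 1x1 convolution mapping features to one logit per class
head = nn.Conv2d(feature_channels, num_classes, kernel_size=1)

features = torch.randn(2, feature_channels, 32, 32)  # stand-in decoder output
logits = head(features)  # shape [2, 5, 32, 32]; no sigmoid applied

mask = (torch.rand(2, num_classes, 32, 32) > 0.5).float()
loss = nn.BCEWithLogitsLoss()(logits, mask)
loss.backward()

# Inference: sigmoid + threshold gives an independent yes/no per class
# (0.5 is a common but arbitrary threshold)
pred_mask = torch.sigmoid(logits.detach()) > 0.5
```

Because each channel is thresholded independently, a pixel can end up in several classes, or in none.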

Regarding the implicit order between channels and classes, I’m not entirely sure what you mean by this.

Your predictions and mask have the same shape. BCEWithLogitsLoss
performs an element-wise binary-cross-entropy computation (and then,
typically a reduction), so the out_channels dimension of your predictions
and the num_classes dimension of your mask do line up.
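To make the element-wise behaviour concrete, the sketch below (random data, 4 classes chosen arbitrarily) uses reduction="none" to keep one loss value per element and checks it against the binary-cross-entropy formula written out by hand. Channel c of the logits is only ever compared with channel c of the mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(1, 4, 3, 3)
mask = (torch.rand(1, 4, 3, 3) > 0.5).float()

# reduction="none" keeps one loss value per element, same shape as inputs
per_element = nn.BCEWithLogitsLoss(reduction="none")(logits, mask)

# Same computation written out by hand
p = torch.sigmoid(logits)
manual = -(mask * torch.log(p) + (1 - mask) * torch.log(1 - p))

# The default reduction ("mean") just averages these values
mean_loss = nn.BCEWithLogitsLoss()(logits, mask)

# Per-class loss: average over the batch and spatial dimensions
per_class = per_element.mean(dim=(0, 2, 3))
```

The per-class breakdown can be handy for spotting a class the model is struggling with.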

As an aside, these comments apply to the multi-label case, where a given
pixel can be assigned to none, some, or all of your classes at the same
time. If instead you are working with what is called the multi-class case,
where a given pixel is in exactly one class at a time, then these comments
don’t directly apply.
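For contrast, here is a sketch of one common way to set up the multi-class case in PyTorch (random data, 5 classes assumed): the predictions keep the [nBatch, num_classes, height, width] shape, but the target becomes a [nBatch, height, width] tensor of integer class indices and the loss is CrossEntropyLoss instead of BCEWithLogitsLoss.

```python
import torch
import torch.nn as nn

num_classes = 5
logits = torch.randn(2, num_classes, 32, 32)  # [nBatch, num_classes, H, W]

# Multi-class: one class index per pixel, shape [nBatch, H, W], dtype int64
target = torch.randint(0, num_classes, (2, 32, 32))

loss = nn.CrossEntropyLoss()(logits, target)

# Predicted class per pixel: argmax over the channel dimension
pred = logits.argmax(dim=1)  # shape [2, 32, 32]
```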

Best.

K. Frank


Thanks Frank for the detailed explanation. You confirmed my assumption.

Regarding:

I guess there is an implicit order between channels and classes. For example, channel1 is always for class1, channel2 is always for class2

What I meant is that each mask channel is always assigned the same category/label. For example, let’s say there are 4 classes: [“cat”, “dog”, “lion”, “wolf”]. We would then have an array/tensor of shape (4, H, W) where the cat mask is always on the first channel of the mask, the dog mask is always on the second channel, the lion mask is always on the third channel, and the wolf mask is always on the fourth channel. Essentially, there’s a fixed mapping between categories and channels.
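A minimal sketch of such a fixed mapping, using the four example classes. The helper build_mask and the per-class input dict are hypothetical, just one way to organize it: the order of the class list defines the channel index, and the same order has to be used when reading the model’s output channels.

```python
import torch

# Fixed class order: list index = channel index
CLASSES = ["cat", "dog", "lion", "wolf"]
CLASS_TO_CHANNEL = {name: i for i, name in enumerate(CLASSES)}

def build_mask(per_class_masks, height, width):
    """Hypothetical helper: stack per-class binary masks into a
    (num_classes, H, W) float tensor.

    per_class_masks: dict mapping class name -> (H, W) bool tensor.
    Classes missing from the dict get an all-zero channel.
    """
    mask = torch.zeros(len(CLASSES), height, width)
    for name, m in per_class_masks.items():
        mask[CLASS_TO_CHANNEL[name]] = m.float()
    return mask

h, w = 4, 4
dog = torch.zeros(h, w, dtype=torch.bool)
dog[:2, :2] = True
mask = build_mask({"dog": dog}, h, w)
# mask[1] is the dog channel; mask[0], mask[2], mask[3] stay all zero
```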

I don’t know if I’ve made myself clear. Thanks again for your explanation.

Best wishes