Multiclass Segmentation

Most likely you should not apply any normalization to your segmentation masks, as this will distort the class indices.
Could you print the shape of label before passing it to torch.from_numpy?
I would assume the channel is in dim0, or your images don’t have a channel dimension at all if you are loading them with PIL.
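A quick way to check this could look like the sketch below. The small in-memory image is only a hypothetical stand-in for your mask file (you would use `Image.open` on your own path instead), and the names `label`, `rgb_label` and `target` are made up for illustration.

```python
import numpy as np
import torch
from PIL import Image

# Hypothetical stand-in for a mask file: 4x6 pixels, class indices 0..2.
# In your code this would be something like Image.open(<your mask path>).
mask_img = Image.fromarray(np.array([[0, 0, 1, 1, 2, 2]] * 4, dtype=np.uint8))

# A single-channel PIL image converts to a 2D array without a channel dim.
label = np.array(mask_img)
print(label.shape)  # (4, 6)

# An RGB image converts to a 3D array with the channels in the last dim.
rgb_label = np.array(mask_img.convert("RGB"))
print(rgb_label.shape)  # (4, 6, 3)

# No normalization here: the class indices are kept as integers.
target = torch.from_numpy(label).long()
print(target.shape, target.dtype)  # torch.Size([4, 6]) torch.int64
```

If the printed shape has three dimensions, the mask is most likely stored as a color image and needs the color-to-index mapping rather than a direct conversion.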

Have a look at this post where I’ve explained the mapping a bit better.
Basically your targets should contain class indices in [0, nb_classes-1].
However, sometimes your segmentation images use a color code for certain classes, e.g. red could be a car and blue could be a tree. Using a mapping, you would have to transform these color codes to class indices, e.g. red->0, blue->1, …
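As a minimal sketch of such a mapping, assuming an RGB mask of shape [H, W, 3] and a made-up color code (pure red for class 0, pure blue for class 1), something like this should work. The helper name `rgb_to_class_indices` and the `mapping` dict are hypothetical, so adapt them to the colors your dataset actually uses.

```python
import numpy as np
import torch

# Hypothetical color code: red -> class 0, blue -> class 1.
mapping = {(255, 0, 0): 0, (0, 0, 255): 1}


def rgb_to_class_indices(mask_rgb, mapping):
    """Convert an [H, W, 3] color mask to an [H, W] array of class indices."""
    h, w, _ = mask_rgb.shape
    # Start with -1 so pixels with an unmapped color can be detected afterwards.
    target = np.full((h, w), -1, dtype=np.int64)
    for color, class_idx in mapping.items():
        # True where all three channels match the current color.
        matches = np.all(mask_rgb == np.array(color, dtype=mask_rgb.dtype), axis=-1)
        target[matches] = class_idx
    if (target == -1).any():
        raise ValueError("mask contains a color that is not in the mapping")
    return target


# Tiny 2x2 example mask using the two colors above.
mask_rgb = np.array(
    [[[255, 0, 0], [0, 0, 255]],
     [[0, 0, 255], [255, 0, 0]]],
    dtype=np.uint8,
)

target = torch.from_numpy(rgb_to_class_indices(mask_rgb, mapping))
print(target)
# tensor([[0, 1],
#         [1, 0]])
```

The resulting tensor has shape [H, W] and dtype `torch.int64`, with values in [0, nb_classes-1]. Raising on unmapped colors is just one choice here; it tends to surface labels that were distorted by resizing or compression early.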

PS: If you are resizing the mask images, make sure to use nearest neighbor interpolation, as other interpolation techniques might distort the labels/colors.
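To see the effect, here is a small sketch that resizes a toy mask with PIL using both nearest neighbor and bilinear interpolation. The mask values (classes 0 and 10) and the target size are arbitrary choices for the demonstration.

```python
import numpy as np
from PIL import Image

# Toy mask with two classes: 0 on the left half, 10 on the right half.
mask = np.zeros((4, 4), dtype=np.uint8)
mask[:, 2:] = 10
mask_img = Image.fromarray(mask)

# Nearest neighbor only copies existing pixel values.
nearest = np.array(mask_img.resize((8, 8), resample=Image.NEAREST))

# Bilinear blends neighboring pixels, so new values can show up at the border.
bilinear = np.array(mask_img.resize((8, 8), resample=Image.BILINEAR))

print(np.unique(nearest))   # [ 0 10]
print(np.unique(bilinear))  # may contain values between 0 and 10
```

Any in-between value produced by bilinear interpolation would be treated as a separate (and wrong) class index, which is why nearest neighbor is the safer choice for masks.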
