Semantic segmentation, how does it work?

I'm trying to run U-Net. My dataset has 3 classes (2 classes + background), so I made a mask that contains the class index for each pixel, like [[0,1,2,1,2,0,...,1],[0,1,1,0,0,...,1]...[0,1,1,0,1,...,2]]. The mask shape is (512, 512), and the class indices are 0, 1, 2 (0 for background). But I wonder how the model works! My model's (U-Net) input has 1 channel (grayscale image), and its output has 3 channels.
I have no idea which loss function I need to use. Why?
And if the answer is CrossEntropyLoss, how does it work when the output has 3 channels but the mask has only 1?
Please, God, help me.
(Or did I do something wrong?)

You should use nn.CrossEntropyLoss. Have a look at the docs to see how the loss formula is implemented.
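For reference, the per-sample formula from the docs, written for logits $x$ over $C$ classes and a target class index $y$ (ignoring the optional `weight` argument), is roughly the following; with the default `reduction='mean'` the per-sample losses are averaged over the batch:

$$
\ell(x, y) = -\log\frac{\exp(x_y)}{\sum_{c=0}^{C-1}\exp(x_c)} = -x_y + \log\sum_{c=0}^{C-1}\exp(x_c),
\qquad
\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\ell(x_n, y_n)
$$

The key point is the term $x_y$: the target is not compared channel-by-channel with the output, it is used as an index that selects one logit per sample.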
I would recommend skipping the spatial dimensions for now and just looking at the standard multi-class classification use case, where the output has the shape [batch_size, nb_classes] and the target contains class indices in the range [0, nb_classes-1] with the shape [batch_size].
As you can see in the docs, the output logits can be directly indexed using the target.
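A minimal sketch of the classification case, assuming random tensors in place of a real model output; it shows that nn.CrossEntropyLoss gives the same value as taking `log_softmax` over the class dimension and indexing it with the target:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, nb_classes = 4, 3
output = torch.randn(batch_size, nb_classes)          # raw logits, no softmax applied
target = torch.randint(0, nb_classes, (batch_size,))  # class indices in [0, nb_classes-1], dtype int64

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

# Manual version: log_softmax over the class dimension, then pick the
# log-probability of the target class for each sample by indexing.
log_probs = F.log_softmax(output, dim=1)
manual_loss = -log_probs[torch.arange(batch_size), target].mean()

print(loss.item(), manual_loss.item())
```

Note that the model should return raw logits; nn.CrossEntropyLoss applies `log_softmax` internally, so adding a softmax layer to the model would apply it twice.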

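For the multi-class segmentation use case, the same formula would be applied, but for each pixel location: the model output now has the shape [batch_size, nb_classes, height, width] and the target has the shape [batch_size, height, width], i.e. both get the additional [..., height, width] dimensions. The target has no channel dimension because each pixel stores a single class index that selects one of the nb_classes output channels, so your (512, 512) mask per image is exactly the expected format (as a LongTensor).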
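A minimal sketch of the segmentation case, assuming random tensors stand in for the U-Net output and the masks; the 8x8 spatial size (instead of 512x512) and the mask being loaded with an extra channel dimension [N, 1, H, W] are illustrative assumptions. It also checks that the result matches treating every pixel as its own classification sample:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, nb_classes, height, width = 2, 3, 8, 8   # small stand-in for 512x512
output = torch.randn(batch_size, nb_classes, height, width)  # U-Net logits: [N, C, H, W]
mask = torch.randint(0, nb_classes, (batch_size, 1, height, width))  # assumed: mask loaded as [N, 1, H, W]

# The target must not have a channel dimension: drop it and make sure it is int64.
target = mask.squeeze(1).long()   # [N, H, W]

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

# Equivalent per-pixel view: move the classes to the last dim and flatten every
# pixel into its own "sample", which is the standard [batch_size, nb_classes] case.
flat_output = output.permute(0, 2, 3, 1).reshape(-1, nb_classes)  # [N*H*W, C]
flat_target = target.reshape(-1)                                   # [N*H*W]
flat_loss = criterion(flat_output, flat_target)

# At inference time, the predicted mask is the argmax over the class channels.
pred = output.argmax(dim=1)       # [N, H, W], same shape as the target

print(loss.item(), flat_loss.item())
```

With the default `reduction='mean'`, both calls average over all N*H*W pixels, which is why the two values agree.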
