I'm trying to run U-Net. My dataset has 3 classes (2 classes + background), so I made a mask that contains the class index for each pixel, like [[0,1,2,1,2,0,...,1],[0,1,1,0,0,...,1]...[0,1,1,0,1,...,2]]. The mask shape is (512,512), and the class indices are 0, 1, 2 (0 for background). But I wonder how the model works! My model's (U-Net) input has 1 channel (grayscale image), and its output has 3 channels.
I have no idea which loss function I need to use, or why.
And if the answer is CrossEntropyLoss, how does it work when the output has 3 channels but the mask has only 1?
Please god help me.
(Or did I do something wrong?)
You should use nn.CrossEntropyLoss, and you can have a look at the docs to see how the loss formula is implemented.
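For reference, here is the formula from the docs written out for the plain classification case, assuming no class weights, no `ignore_index`, and the default `reduction='mean'`. Here $x$ are the raw logits of shape [N, C] and $y_n \in \{0, \dots, C-1\}$ is the target class index of sample $n$:

```latex
\ell(x, y) = \frac{1}{N} \sum_{n=1}^{N} \left( -x_{n, y_n} + \log \sum_{c=0}^{C-1} \exp\left(x_{n, c}\right) \right)
```

Since `log_softmax` is applied internally, the model should return raw logits (no softmax at the end of the U-Net).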
I would recommend skipping the spatial dimensions for now and just looking at the standard multi-class classification use case, where the output has the shape [batch_size, nb_classes] and the target contains class indices in the range [0, nb_classes-1] with the shape [batch_size].
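A minimal sketch of that classification case, assuming random logits and a made-up target just to show the expected shapes and dtype (the target must hold `long` class indices, not one-hot vectors):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, nb_classes = 4, 3

# Raw, unnormalized scores (logits), e.g. from a model: shape [batch_size, nb_classes]
output = torch.randn(batch_size, nb_classes)

# One class index per sample in [0, nb_classes-1]: shape [batch_size], dtype int64 (long)
target = torch.tensor([0, 2, 1, 2])

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)  # scalar, averaged over the batch by default

print(output.shape, target.shape, loss.item())
```

Note that the target has one dimension fewer than the output: the class dimension is "consumed" by the index.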
As you can see in the docs, the output logits can be directly indexed using the target, i.e. for each sample n the loss uses the logit output[n, target[n]] of the target class.
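To make the indexing explicit, here is a sketch (assuming the same toy shapes as above, unweighted, `reduction='mean'`) that computes the loss by hand and compares it against `nn.CrossEntropyLoss`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
output = torch.randn(4, 3)           # logits: [batch_size, nb_classes]
target = torch.tensor([0, 2, 1, 2])  # class indices: [batch_size]

# Index the logits with the target: picks output[n, target[n]] for every sample n
picked = output[torch.arange(output.size(0)), target]  # shape [batch_size]

# Per-sample loss: -x[n, y_n] + log(sum_c exp(x[n, c])), then averaged over the batch
manual = (-picked + torch.logsumexp(output, dim=1)).mean()

reference = nn.CrossEntropyLoss()(output, target)
print(torch.allclose(manual, reference))  # True
```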
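For the multi-class segmentation use case, the same formula would be applied, but for each pixel location, since the model output and the target now have additional [..., height, width] dimensions. So the output has the shape [batch_size, nb_classes, height, width] and the target (your mask) has the shape [batch_size, height, width] with class indices, which is why the mask does not need a channel dimension.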
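A sketch of the segmentation case with the shapes from the question, assuming random tensors in place of a real U-Net output and mask. It also checks that the result equals treating every pixel as one sample of the classification case:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, nb_classes, height, width = 2, 3, 512, 512

# Stand-in for the U-Net output (raw logits, no softmax): [batch_size, 3, 512, 512]
output = torch.randn(batch_size, nb_classes, height, width)

# Stand-in for the masks with class indices 0, 1, 2: [batch_size, 512, 512], dtype int64
target = torch.randint(0, nb_classes, (batch_size, height, width))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)

# Same loss by flattening: move the class dim last, one row per pixel
flat_output = output.permute(0, 2, 3, 1).reshape(-1, nb_classes)  # [N*H*W, 3]
flat_target = target.reshape(-1)                                   # [N*H*W]
loss_flat = criterion(flat_output, flat_target)

print(torch.allclose(loss, loss_flat, atol=1e-6))  # True
```

With a real dataset, each (512, 512) mask would be converted to a `long` tensor, and the DataLoader would stack them into [batch_size, 512, 512] without adding a channel dimension.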