I am trying to learn image segmentation in PyTorch. I learned that if I have, say, 100x100 images and 5 classes, then the output of the model will be (Batch_Size, 5, 100, 100), while the mask images are (Batch_Size, 100, 100). How do we compare a 5-channel output to a mask that does not even have a channel dimension? That would make sense to me if there were a final argmax over the channel dimension in the architecture, so that we would have a predicted class for each pixel and could compare that with the mask, but apparently no model has such an argmax.
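To make the shapes concrete, here is a minimal sketch of what I mean, using random tensors instead of a real model and mask (batch size 4 is just an arbitrary choice for the example):

```python
import torch
import torch.nn as nn

# Toy shapes matching my setup: batch of 4, 5 classes, 100x100 images
batch_size, num_classes, H, W = 4, 5, 100, 100

# Stand-in for the model output: raw per-class scores, one channel per class
output = torch.randn(batch_size, num_classes, H, W)

# Stand-in for the masks: one class index (0..4) per pixel, no channel dimension
target = torch.randint(0, num_classes, (batch_size, H, W))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)
print(loss.shape)  # torch.Size([]) -- a single scalar, despite the shape mismatch
```

So the loss runs without error even though `output` and `target` have different numbers of dimensions, which is exactly what confuses me.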
What is the math behind `nn.CrossEntropyLoss()(output, target)` in this case? Any help would be much appreciated.