The end goal of my model is to classify each pixel, i.e. to predict for every pixel the class given by the pixel at the same position in the label. My model outputs images of shape [1, 3, 192, 144], and my labels have the shape [1, 192, 144]. I can't figure out how to compute the loss and train my model with these shapes. Could someone explain it more clearly or point me to some documentation? (I already searched, but I couldn't find anyone doing training and loss computation in a compatible way.)
For now I am computing a cross-entropy loss, but I don't think it is the correct method.
I think it might be the correct method.
Based on your description it seems you are working on a multi-class segmentation use case, where each pixel belongs to one specific class.
Your model output and target shapes look good if you are dealing with 3 classes. The target should contain the class indices in the range [0, nb_classes-1] (as a torch.long tensor), and you should be fine using nn.CrossEntropyLoss.
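A minimal sketch of how these shapes fit together, using random tensors in place of a real model output and label (the shapes are taken from the question; the values are made up). nn.CrossEntropyLoss takes raw logits of shape [N, C, H, W] and a target of class indices of shape [N, H, W]; it applies log_softmax internally, so the model output should not be passed through softmax first.

```python
import torch
import torch.nn as nn

# Shapes from the question: batch of 1, 3 classes, 192x144 pixels.
batch_size, num_classes, height, width = 1, 3, 192, 144

# Stand-in for the raw, unnormalized model output (logits): one channel per class.
logits = torch.randn(batch_size, num_classes, height, width)

# Stand-in for the label: one class index per pixel, dtype long, no channel dimension.
target = torch.randint(0, num_classes, (batch_size, height, width), dtype=torch.long)

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)  # scalar, averaged over all pixels in the batch
print(loss.item())
```

The loss is computed per pixel and then averaged, so no manual flattening of the spatial dimensions is needed.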
Thank you. I am working on multi-class segmentation as you described, but I actually have 4 classes. My label is a single-channel grayscale image with 4 possible values, and my output has the same size as the input image (3 color channels and a height and width of 192x144). Could you please explain why the output should have this exact shape in this case?
Moreover, when loading my label I first used the ToTensor() transform, but it scales everything into the range [0, 1], so I replaced it with a transform that maps the values to integers in [0, 3] (because there are 4 classes). Is it correct to do it that way?
In that case the output should have the shape [batch_size, 4, height, width], where each channel corresponds to the logit for a class.
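A minimal training-step sketch for the 4-class case. The small convolutional network below is a hypothetical placeholder for a real segmentation model (e.g. a U-Net); the only property that matters here is that it maps a [N, 3, H, W] image to [N, 4, H, W] logits. The SGD optimizer and learning rate are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

num_classes = 4

# Hypothetical stand-in for a real segmentation network:
# 3 input channels (RGB) -> num_classes logit channels, same height and width.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, num_classes, kernel_size=1),
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

image = torch.randn(1, 3, 192, 144)                  # input image
mask = torch.randint(0, num_classes, (1, 192, 144))  # class indices in [0, 3], dtype long

def train_step(image, mask):
    model.train()
    optimizer.zero_grad()
    logits = model(image)           # [1, 4, 192, 144]
    loss = criterion(logits, mask)  # mask: [1, 192, 144]
    loss.backward()
    optimizer.step()
    return loss.item()

loss_value = train_step(image, mask)

# At inference time, the predicted class of each pixel is the channel with the largest logit.
model.eval()
with torch.no_grad():
    pred = model(image).argmax(dim=1)  # [1, 192, 144], values in [0, 3]
```

The argmax over the channel dimension turns the [N, 4, H, W] output back into a [N, H, W] map that can be compared directly with the label.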
Some transformations on the mask are dangerous, as they might corrupt the mask. ToTensor is such an operation (it rescales uint8 values to floats in [0, 1]), so you should make sure that your mask contains the class indices in the range [0, 3], as you described.
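One way to turn a grayscale mask into class indices without ToTensor is sketched below. It assumes the mask has already been loaded as an [H, W] uint8 NumPy array (e.g. via np.array on a PIL image) and that the four classes are stored as the gray levels 0, 85, 170 and 255; both the helper name mask_to_class_indices and the GRAY_LEVELS values are hypothetical, so check the values actually present in your labels with np.unique first.

```python
import numpy as np
import torch

# Assumption: the four classes are encoded as these gray levels in the label images.
GRAY_LEVELS = [0, 85, 170, 255]

def mask_to_class_indices(mask: np.ndarray) -> torch.Tensor:
    """Map an [H, W] uint8 grayscale mask to a long tensor of class indices."""
    mask_t = torch.from_numpy(mask)
    target = torch.full(mask.shape, -1, dtype=torch.long)
    for class_idx, gray in enumerate(GRAY_LEVELS):
        target[mask_t == gray] = class_idx
    # Any pixel still at -1 had a gray value outside GRAY_LEVELS.
    if (target < 0).any():
        raise ValueError("mask contains unexpected gray values")
    return target

example = np.array([[0, 85], [170, 255]], dtype=np.uint8)
print(mask_to_class_indices(example))

# For comparison, ToTensor-style scaling divides uint8 values by 255,
# which yields floats instead of class indices.
print(torch.from_numpy(example).float() / 255)
```

If the label images already store the values 0 to 3 directly, the lookup is unnecessary and torch.as_tensor(mask, dtype=torch.long) is enough.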
Also, if you are resizing the mask, I would recommend using the NEAREST interpolation method, as other methods (e.g. bilinear) blend neighboring pixels and might also corrupt the class indices.
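A small sketch of resizing an already-converted [N, H, W] class-index mask with torch.nn.functional.interpolate. The helper name resize_mask and the target size (96, 72) are illustrative only. For PIL masks, torchvision's Resize with InterpolationMode.NEAREST plays the same role.

```python
import torch
import torch.nn.functional as F

def resize_mask(mask: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """Resize an [N, H, W] class-index mask without creating new class values."""
    # interpolate expects [N, C, H, W]; cast to float for broad dtype support.
    resized = F.interpolate(mask.unsqueeze(1).float(), size=size, mode="nearest")
    return resized.squeeze(1).long()

mask = torch.randint(0, 4, (1, 192, 144))
small = resize_mask(mask, (96, 72))  # every value is still one of 0, 1, 2, 3

# For comparison, bilinear interpolation averages neighboring labels and
# produces fractional values that are not valid class indices.
blended = F.interpolate(mask.unsqueeze(1).float(), size=(96, 72),
                        mode="bilinear", align_corners=False)
```

Nearest-neighbor interpolation only ever copies existing pixel values, which is why it keeps the mask valid.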