Sigmoid/Softmax for Multiclass Segmentation - Colab Attached

Please find the link to the Colab notebook. It is well annotated, and I hope you can help me!

I am implementing a UNet architecture to segment input images into 4 classes.

Input: 900 grayscale images, shape [900, 1, 256, 256]
Mask/ground truth: shape [900, 1, 256, 256]

Each pixel in the mask has one of the following values: 0, 1, 2, 3 (4 classes).
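A quick sanity check on masks like these, as a minimal sketch: it assumes the masks are loaded as a float tensor in the [N, 1, H, W] layout described above, and it uses a small random stand-in (8 masks instead of the real 900) purely for illustration. It confirms that every pixel holds a class index and counts pixels per class.

```python
import torch

# Hypothetical stand-in data: 8 random masks instead of the real 900, same [N, 1, H, W] layout
torch.manual_seed(0)
masks = torch.randint(0, 4, (8, 1, 256, 256)).float()  # masks are often loaded as float

# Every pixel should hold one of the class indices 0..3 (torch.unique returns sorted values)
values = torch.unique(masks)
print(values)

# Pixel count per class, handy for spotting class imbalance
counts = torch.bincount(masks.long().flatten(), minlength=4)
print(counts, counts.sum().item())
```

On real data, any value outside 0..3 (for example 255 from a saved PNG) would show up in `values` and needs to be remapped before training.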

I understand the output shape of my model should be [batch_size, nb_classes, height, width].

In my case, the model outputs a tensor of size [batch_size, 4, 256, 256].
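To make the expected output shape concrete, here is a minimal sketch. The UNet itself is not reproduced; a single `nn.Conv2d` is used as a hypothetical stand-in, since any network that maps [B, 1, H, W] to [B, nb_classes, H, W] produces the same output layout.

```python
import torch
import torch.nn as nn

nb_classes = 4
# Hypothetical placeholder for the UNet: maps [B, 1, H, W] -> [B, nb_classes, H, W]
# (kernel_size=3 with padding=1 keeps the spatial size unchanged)
model = nn.Conv2d(in_channels=1, out_channels=nb_classes, kernel_size=3, padding=1)

x = torch.randn(2, 1, 256, 256)  # a batch of 2 grayscale images
logits = model(x)                # one raw score per class per pixel
print(logits.shape)              # torch.Size([2, 4, 256, 256])
```

The 4 channels are per-class scores for each pixel, not a label map, which is why they cannot be compared with the mask directly.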

After training with CrossEntropy loss, I try to make predictions with the trained model.

The prediction for a single image, y_pred, has shape [4, 256, 256], whereas my mask has shape [1, 256, 256].
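One way to reconcile the two shapes, sketched under the assumption that `y_pred` holds the raw scores for a single image (no batch dimension) and the mask keeps its [1, H, W] layout. Random tensors stand in for the real prediction and mask. Taking the argmax over the class dimension turns the 4 score channels into one class index per pixel.

```python
import torch

torch.manual_seed(0)
# Stand-in for the model output for ONE image: 4 class scores per pixel, [4, H, W]
y_pred = torch.randn(4, 256, 256)
# Stand-in for the matching ground-truth mask, [1, H, W]
mask = torch.randint(0, 4, (1, 256, 256))

# Pick the highest-scoring class at every pixel -> [256, 256] of class indices
label_map = torch.argmax(y_pred, dim=0)
# Restore the leading channel dim so it matches the mask layout -> [1, 256, 256]
label_map = label_map.unsqueeze(0)

print(label_map.shape, mask.shape)  # torch.Size([1, 256, 256]) torch.Size([1, 256, 256])
```

Here `dim=0` is the class dimension because there is no batch dimension; with a batched output of shape [B, 4, H, W] the class dimension is `dim=1`.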

How can I compare my predictions with the ground truth when the shapes don't match?
Ideally, my prediction should have shape [1, 256, 256] with 4 unique values (nb_classes = 4).
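Once the prediction is a label map with the same shape as the mask, the comparison is element-wise. The sketch below uses two common metrics, pixel accuracy and per-class IoU; the helper names `pixel_accuracy` and `per_class_iou` are hypothetical, and the tiny hand-made label maps are illustrative only.

```python
import torch

def pixel_accuracy(pred: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of pixels whose predicted class equals the true class (same-shape tensors)."""
    return (pred == target).float().mean().item()

def per_class_iou(pred: torch.Tensor, target: torch.Tensor, nb_classes: int) -> list:
    """Intersection-over-union for each class; NaN if a class is absent from both tensors."""
    ious = []
    for c in range(nb_classes):
        p, t = pred == c, target == c
        inter = (p & t).sum().item()
        union = (p | t).sum().item()
        ious.append(inter / union if union > 0 else float("nan"))
    return ious

# Tiny hand-made [1, 2, 3] label maps with classes 0..3
pred = torch.tensor([[[0, 1, 2], [3, 3, 1]]])
target = torch.tensor([[[0, 1, 2], [3, 2, 1]]])

print(pixel_accuracy(pred, target))    # 5 of 6 pixels match, about 0.833
print(per_class_iou(pred, target, 4))  # [1.0, 1.0, 0.5, 0.5]
```

Per-class IoU is often more informative than pixel accuracy when one class (for example background) dominates the masks.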

I read that I am not supposed to use torch.sigmoid() with nn.CrossEntropyLoss. What is the right way to get predictions?
I would also like to clarify whether I have to one-hot encode or otherwise map my masks.
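On the one-hot question, a minimal sketch on random stand-in tensors: `nn.CrossEntropyLoss` given a class-index target gives the same value as the cross-entropy computed by hand from a one-hot target. The one-hot version below exists only for comparison; it is not something you need to do before calling the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nb_classes = 4
logits = torch.randn(2, nb_classes, 8, 8)         # [B, C, H, W] raw model output (stand-in)
target = torch.randint(0, nb_classes, (2, 8, 8))  # [B, H, W] class indices, int64

# Built-in loss: consumes class indices directly
loss_builtin = nn.CrossEntropyLoss()(logits, target)

# Same quantity by hand with a one-hot target, for comparison only
one_hot = F.one_hot(target, num_classes=nb_classes).permute(0, 3, 1, 2).float()  # -> [B, C, H, W]
loss_manual = -(one_hot * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

print(torch.allclose(loss_builtin, loss_manual))  # True
```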

nn.CrossEntropyLoss expects raw logits as the model output, in the shape [batch_size, nb_classes, height, width] for a multi-class segmentation use case, since it applies log_softmax internally. So you should remove the sigmoid.
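A small illustration of why the sigmoid hurts, using hand-made logits rather than a real model. `nn.CrossEntropyLoss` matches `nll_loss(log_softmax(x))`, so it already normalizes the scores. If a sigmoid squashes them into (0, 1) first, even a perfectly confident prediction cannot push the loss much below about 0.74 for 4 classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
target = torch.zeros(1, 4, 4, dtype=torch.long)  # every pixel belongs to class 0

# Hand-made logits that confidently predict class 0 at every pixel, [B, C, H, W]
logits = torch.full((1, 4, 4, 4), -10.0)
logits[:, 0] = 10.0

# The loss applies log_softmax internally: it matches nll_loss on log_softmax outputs
assert torch.allclose(criterion(logits, target), F.nll_loss(F.log_softmax(logits, dim=1), target))

loss_raw = criterion(logits, target)                 # essentially 0
loss_sig = criterion(torch.sigmoid(logits), target)  # about 0.74, scores squashed into (0, 1)
print(loss_raw.item(), loss_sig.item())
```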
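Also, the target should be a LongTensor in the shape [batch_size, height, width] containing the class indices in the range [0, nb_classes-1]. Your masks already store class indices, so no one-hot encoding is needed; just remove the channel dimension and cast them, e.g. target = mask.squeeze(1).long().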
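Preparing the target from masks stored as [B, 1, H, W] could look like this minimal sketch. It assumes the masks come in as float (a common case when loading images) and uses random stand-ins for both the masks and the UNet output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in batch of masks in the question's layout, [B, 1, H, W], stored as float
masks = torch.randint(0, 4, (2, 1, 256, 256)).float()
logits = torch.randn(2, 4, 256, 256)  # stand-in for the UNet output, [B, C, H, W]

# Drop the channel dim and cast to int64 (LongTensor) class indices -> [B, H, W]
target = masks.squeeze(1).long()
# Indices must lie in [0, nb_classes - 1] for the loss to be valid
assert target.min() >= 0 and target.max() <= 3

loss = nn.CrossEntropyLoss()(logits, target)
print(target.shape, target.dtype, loss.item())
```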
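To get the predicted class labels you could use pred = torch.argmax(output, dim=1), which returns a [batch_size, height, width] tensor that can be compared directly with the target. Applying softmax before the argmax is not necessary, since it does not change which class has the highest score.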
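An inference sketch along these lines, with a single `nn.Conv2d` as a hypothetical stand-in for the trained UNet and random images and targets (kept small, 16x16). It shows that the argmax gives a [B, H, W] label map comparable to the target, and that a softmax beforehand picks the same classes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # hypothetical stand-in for the trained UNet
images = torch.randn(2, 1, 16, 16)                 # small random stand-in images
target = torch.randint(0, 4, (2, 16, 16))          # class indices, [B, H, W]

model.eval()
with torch.no_grad():
    output = model(images)              # raw logits, [B, 4, H, W]
    pred = torch.argmax(output, dim=1)  # class labels, [B, H, W]

# Softmax is monotonic within each pixel, so it selects the same class
probs = torch.softmax(output, dim=1)
assert torch.equal(pred, torch.argmax(probs, dim=1))

accuracy = (pred == target).float().mean().item()
print(pred.shape, accuracy)
```

Softmax is still useful if you need per-class probabilities (for example for confidence maps), but for hard labels the argmax of the logits is enough.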