Multiclass Image Segmentation setting

Hi, I’m currently working on a multiclass image segmentation project. There are 4 classes (including background, class 0) in the dataset. The batches from my image loader have size (10,1,244,244), and the output of my UNet is (10,4,244,244). As you can see, I set the number of output channels to 4 because there are 4 classes. However, the mask (target) size is (10,1,244,244), which has only 1 channel even though there are 4 classes.

I’m using the dice loss from here: torchgeometry.losses.dice — PyTorch Geometry documentation,
whose target is expected to have size (B, H, W). So I squeeze out the channel dimension (dim 1) of the mask (target), making it (10,244,244). To sum up: image input (10,1,244,244), network output (10,4,244,244), target mask (10,244,244).
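For reference, here is a minimal sketch of how I understand the shapes are supposed to fit together: raw logits of shape (B, C, H, W), an integer label mask squeezed to (B, H, W), and a softmax/one-hot dice loss along the lines of what torchgeometry implements. The `dice_loss` function below is my own re-implementation for illustration, not the library's actual code:

```python
import torch
import torch.nn.functional as F

def dice_loss(logits, target, eps=1e-6):
    """Sketch of a multiclass soft dice loss.

    logits: (B, C, H, W) raw network output
    target: (B, H, W) integer class labels in [0, C)
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    # one-hot encode target: (B, H, W) -> (B, C, H, W)
    target_one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (2, 3)  # sum over the spatial dimensions
    intersection = torch.sum(probs * target_one_hot, dims)
    cardinality = torch.sum(probs + target_one_hot, dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return torch.mean(1.0 - dice)

B, C, H, W = 10, 4, 244, 244
logits = torch.randn(B, C, H, W)             # stand-in for the UNet output
mask = torch.randint(0, C, (B, 1, H, W))     # loader gives (B, 1, H, W)
target = mask.squeeze(1).long()              # squeeze channel dim -> (B, H, W)
loss = dice_loss(logits, target)
```

One thing I double-check with this setup is that the target really is a long tensor of class indices (not a float one-hot mask), since the one-hot encoding assumes integer labels.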

I’m wondering: is this setup right for multiclass segmentation? I’ve been training for about 100 epochs, but the training loss is stuck at about 0.55, so I’m thinking something is probably wrong… Thank you!