Multiclass semantic segmentation giving zero in all output channels except for the first channel

I am training a UNet model for multiclass semantic segmentation. My training loop is as follows:

for X, y in dataloader:
    X, y = X.to(device), y.to(device).squeeze().long()  # target as [N, H, W] class indices
    optimizer.zero_grad()
    pred = model(X)           # [N, C, H, W]
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()

I am using nn.CrossEntropyLoss as the loss function, which expects the input to have shape [N, C, H, W] and the target to have shape [N, H, W]. After training, I get a training loss of 0.37 and a validation loss of 0.32, but the prediction looks like the array shown below: only the first channel contains valid logits, while all other channels are 0. (A minimal sketch of the shapes I am passing in follows the array.)

array([[[4.554523 , 4.376083 , 4.3859653, ..., 4.385881 , 4.3868537,
         4.5185075],
        [4.372363 , 4.3757625, 4.387993 , ..., 4.387873 , 4.3772345,
         4.3609285],
        [4.3757234, 4.3763394, 4.378887 , ..., 4.3940897, 4.3748627,
         4.374463 ],
        ...,
        [4.377365 , 4.3794756, 4.3913045, ..., 4.3968244, 4.386844 ,
         4.36672  ],
        [4.379352 , 4.3720374, 4.3786154, ..., 4.3863754, 4.376202 ,
         4.368247 ],
        [4.4423575, 4.385245 , 4.3882155, ..., 4.3879733, 4.387912 ,
         4.3996263]],

       [[0.       , 0.       , 0.       , ..., 0.       , 0.       ,
         0.       ],
        [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
         0.       ],
        [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
         0.       ],
        ...,
        [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
         0.       ],
        [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
         0.       ],
        [0.       , 0.       , 0.       , ..., 0.       , 0.       ,
         0.       ]]], dtype=float32)
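
For reference, here is a minimal sketch of the shapes I am working with; the class count and spatial size below are placeholders, not my real values:

import torch
import torch.nn as nn

# Placeholder sizes for illustration: 2 images, 4 classes, 64x64 masks
N, C, H, W = 2, 4, 64, 64

loss_fn = nn.CrossEntropyLoss()
pred = torch.randn(N, C, H, W)            # raw logits, one channel per class
target = torch.randint(0, C, (N, H, W))   # integer class indices, no channel dimension
loss = loss_fn(pred, target)              # input [N, C, H, W], target [N, H, W]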

What should I do?

Is your model returning logits or are you applying e.g. nn.ReLU on the output?

I am applying ReLU after the very last convolution layer, which has out_channels equal to the number of classes.

Remove it so that your model returns raw logits. nn.CrossEntropyLoss applies log-softmax internally, so a ReLU on the output clamps every negative logit to zero and the model can no longer push a class score down.
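
As a rough sketch of what the end of the decoder could look like without the final activation (the layer names and channel counts here are made up for illustration, not taken from your model):

import torch.nn as nn

num_classes = 4  # placeholder value

# Hypothetical final block of a UNet decoder
head = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),                      # ReLU between hidden layers is fine
    nn.Conv2d(64, num_classes, kernel_size=1),  # no activation here: output raw logits
)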

Thank you, but I have another doubt: should I keep the Batch Normalization layers, or should I remove those as well?

Keep all layers. Call model.train() before the training loop and model.eval() before the validation loop. This makes sure that layers which behave differently during training and evaluation switch modes correctly: for example, batchnorm layers will use their running statistics (instead of the current batch statistics) to normalize the activations during validation runs.
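
A rough sketch of how the two modes are usually toggled (train_loader, val_loader, and num_epochs are placeholder names):

for epoch in range(num_epochs):
    model.train()               # batchnorm uses batch stats, dropout is active
    for X, y in train_loader:
        X, y = X.to(device), y.to(device).squeeze().long()
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()

    model.eval()                # batchnorm uses running stats, dropout is off
    with torch.no_grad():       # no gradients needed for validation
        for X, y in val_loader:
            X, y = X.to(device), y.to(device).squeeze().long()
            val_loss = loss_fn(model(X), y)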