I am training a UNet model for multiclass semantic segmentation. My training loop is as follows:
for X, y in dataloader:
    X, y = X.to(device), y.to(device).squeeze().long()
    pred = model(X)  # [N, C, H, W]
    loss = loss_fn(pred, y)
I am using nn.CrossEntropyLoss as the loss function, which requires the input to have shape [N, C, H, W] and the target to have shape [N, H, W]. After training, the training loss is 0.37 and the validation loss is 0.32. However, the predictions look like the output below: only the first channel contains non-zero logits, while every other channel is all zeros:
array([[[4.554523 , 4.376083 , 4.3859653, ..., 4.385881 , 4.3868537,
[4.372363 , 4.3757625, 4.387993 , ..., 4.387873 , 4.3772345,
[4.3757234, 4.3763394, 4.378887 , ..., 4.3940897, 4.3748627,
4.374463 ],
[4.377365 , 4.3794756, 4.3913045, ..., 4.3968244, 4.386844 ,
4.36672 ],
[4.379352 , 4.3720374, 4.3786154, ..., 4.3863754, 4.376202 ,
4.368247 ],
[4.4423575, 4.385245 , 4.3882155, ..., 4.3879733, 4.387912 ,
[[0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[0. , 0. , 0. , ..., 0. , 0. ,
0. ],
[0. , 0. , 0. , ..., 0. , 0. ,
0. ]]], dtype=float32)
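A minimal sketch of how the shape contract for nn.CrossEntropyLoss and the per-channel logits could be checked. It assumes random tensors standing in for model(X) and y, and the sizes N, C, H, W are made up for illustration, so it is not taken from my actual run:

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
N, C, H, W = 2, 3, 4, 4

torch.manual_seed(0)
pred = torch.randn(N, C, H, W)           # stand-in for model(X): raw logits
target = torch.randint(0, C, (N, H, W))  # stand-in for y: class indices (int64)

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(pred, target)             # scalar (0-dim) tensor

# Per-channel range of the logits: a channel that is zero everywhere
# would show min == max == 0 here.
channel_min = pred.amin(dim=(0, 2, 3))   # shape [C]
channel_max = pred.amax(dim=(0, 2, 3))   # shape [C]

# Predicted class per pixel, and how often each class occurs in the
# prediction versus the target.
pred_classes = pred.argmax(dim=1)        # shape [N, H, W]
pred_counts = torch.bincount(pred_classes.flatten(), minlength=C)
target_counts = torch.bincount(target.flatten(), minlength=C)

print("loss:", loss.item())
print("channel min:", channel_min)
print("channel max:", channel_max)
print("predicted class counts:", pred_counts)
print("target class counts:", target_counts)
```

Running the same checks on the real pred and y would show whether the target actually contains more than one class and whether y keeps its batch dimension after squeeze(), which drops every size-1 dimension, including a batch of size 1.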
What should I do?