Problem with visualizing multi-class predictions

I trained a U-Net (with a softmax output) on the CamVid dataset for multi-class segmentation. Everything seems to run fine, but I am having trouble visualizing the predictions: they look flat, as shown below.

Full code can be accessed here.

The output looks like this (screenshot not included):

Helper function for data visualization:

def visualize(**images):
    """Plot every keyword argument as one panel in a single row.

    Each keyword name becomes its panel's title: underscores are turned into
    spaces and the result is title-cased (``gt_mask`` -> ``Gt Mask``).
    Axis ticks are hidden and the figure is shown immediately.
    """
    count = len(images)
    plt.figure(figsize=(16, 5))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, count, position)
        # Hide tick marks; pixel coordinates add nothing to a mask preview.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(images[label])
    plt.show()

Applying `visualize` to the predictions:

# Pick one test sample; the model expects a leading batch dimension.
image, gt_mask = test_dataset[n]
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)

# Masks are channel-first: the loaders yield [batch, num_classes, H, W]
# (e.g. [8, 12, 384, 480]), so one sample's mask is (num_classes, H, W).
# Select a class along the FIRST axis. Indexing with `[..., k]` picks the
# k-th column of the width axis instead and yields a (num_classes, H) strip,
# which is what made the plots look flat.
gt_mask_0 = gt_mask[0]
gt_mask_1 = gt_mask[1]

# The U-Net ends in a softmax, and model.predict returns a tensor shaped
# (1, num_classes, H, W). Assign each pixel to its most probable class:
# rounding every channel at 0.5 leaves pixels unassigned whenever no single
# class probability exceeds 0.5.
pr_mask = model.predict(x_tensor)
pr_labels = pr_mask.squeeze(0).argmax(dim=0).cpu().numpy()
pr_mask_0 = (pr_labels == 0).astype("float32")
pr_mask_1 = (pr_labels == 1).astype("float32")

# NOTE(review): `image_vis` is not defined in this snippet; presumably it is
# the un-normalized copy of the same sample `n` — confirm.
visualize(
    image=image_vis,
    ground_truth_mask=gt_mask_0,
    sky_mask=pr_mask_0,
)

gt_mask

array([[[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)

pr_mask

tensor([[[[0.0089, 0.0096, 0.0101,  ..., 0.0044, 0.0041, 0.0038],
          [0.0094, 0.0098, 0.0099,  ..., 0.0057, 0.0056, 0.0053],
          [0.0097, 0.0097, 0.0094,  ..., 0.0071, 0.0072, 0.0072],
          ...,
          [0.0126, 0.0129, 0.0128,  ..., 0.0063, 0.0070, 0.0076],
          [0.0113, 0.0116, 0.0115,  ..., 0.0062, 0.0073, 0.0086],
          [0.0101, 0.0103, 0.0102,  ..., 0.0060, 0.0077, 0.0097]],

         [[0.0354, 0.0348, 0.0332,  ..., 0.0047, 0.0061, 0.0078],
          [0.0471, 0.0452, 0.0421,  ..., 0.0052, 0.0068, 0.0086],
          [0.0610, 0.0571, 0.0515,  ..., 0.0054, 0.0072, 0.0094],
          ...,
          [0.0111, 0.0161, 0.0225,  ..., 0.0364, 0.0368, 0.0369],
          [0.0078, 0.0115, 0.0165,  ..., 0.0303, 0.0342, 0.0382],
          [0.0054, 0.0081, 0.0120,  ..., 0.0251, 0.0316, 0.0394]],

         [[0.0090, 0.0070, 0.0053,  ..., 0.0049, 0.0055, 0.0061],
          [0.0097, 0.0084, 0.0070,  ..., 0.0065, 0.0074, 0.0082],
          [0.0102, 0.0098, 0.0090,  ..., 0.0084, 0.0096, 0.0106],
          ...,
          [0.0245, 0.0227, 0.0202,  ..., 0.0142, 0.0136, 0.0129],
          [0.0211, 0.0191, 0.0169,  ..., 0.0131, 0.0131, 0.0128],
          [0.0180, 0.0160, 0.0140,  ..., 0.0120, 0.0124, 0.0127]],

         ...,

         [[0.0259, 0.0252, 0.0239,  ..., 0.0033, 0.0048, 0.0068],
          [0.0253, 0.0244, 0.0228,  ..., 0.0043, 0.0060, 0.0082],
          [0.0242, 0.0230, 0.0211,  ..., 0.0052, 0.0072, 0.0097],
          ...,
          [0.0636, 0.0476, 0.0343,  ..., 0.0413, 0.0394, 0.0371],
          [0.0640, 0.0479, 0.0348,  ..., 0.0459, 0.0416, 0.0373],
          [0.0639, 0.0479, 0.0352,  ..., 0.0508, 0.0437, 0.0372]],

         [[0.7196, 0.7019, 0.6668,  ..., 0.6152, 0.5556, 0.4922],
          [0.6529, 0.6269, 0.5832,  ..., 0.5199, 0.4687, 0.4144],
          [0.5768, 0.5427, 0.4922,  ..., 0.4186, 0.3824, 0.3403],
          ...,
          [0.3214, 0.2712, 0.2201,  ..., 0.3356, 0.3038, 0.2719],
          [0.3338, 0.2885, 0.2423,  ..., 0.3139, 0.2817, 0.2503],
          [0.3442, 0.3053, 0.2652,  ..., 0.2915, 0.2599, 0.2290]],

         [[0.0182, 0.0217, 0.0253,  ..., 0.0382, 0.0371, 0.0354],
          [0.0259, 0.0309, 0.0358,  ..., 0.0667, 0.0543, 0.0434],
          [0.0359, 0.0427, 0.0489,  ..., 0.1109, 0.0769, 0.0520],
          ...,
          [0.0222, 0.0215, 0.0199,  ..., 0.1235, 0.1231, 0.1213],
          [0.0250, 0.0241, 0.0225,  ..., 0.1325, 0.1389, 0.1441],
          [0.0278, 0.0268, 0.0253,  ..., 0.1412, 0.1559, 0.1701]]]],
       device='cuda:0')

Could you post the shapes of the masks you are passing to visualize?

# Check one batch from each loader: images come out as
# [batch_size, 3, H, W] and masks as [batch_size, no_of_classes, H, W];
# the loader length is the number of batches per epoch.
for loader in (train_loader, test_loader):
    images, masks = next(iter(loader))
    print(images.shape)
    print(masks.shape)
    print(len(loader))
#[batch_size, no_of_classes, size, size]

torch.Size([8, 3, 320, 320])
torch.Size([8, 12, 320, 320])
46
torch.Size([8, 3, 384, 480])
torch.Size([8, 12, 384, 480])
13

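I figured out the solution. The masks are channel-first (`[batch_size, no_of_classes, H, W]`, as the shape printout shows). So a class has to be selected along the first axis of a single mask (`mask[k]`), not the last axis (`mask[..., k]`), which picks a column of pixels and produces the flat-looking plot. This worked for me, though I had to display the masks in separate figures.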

    # NOTE(review): this assumes `pr_mask` was already reduced to
    # (num_classes, H, W), e.g. squeezed and moved to CPU. The raw
    # model.predict output is (1, num_classes, H, W), for which `pr_mask[0]`
    # would still be a 3-D stack — confirm.
    # NOTE(review): which class each channel index holds depends on the
    # dataset's class order. In the gt_mask dump in this post, channel 0
    # covers the top rows of the image — confirm that index 0 really is "car".
    visualize(
        image=denormalize(image_vis.squeeze()),
        gt_mask_car=gt_mask[0].squeeze(),
        pr_mask_car=pr_mask[0].squeeze(),
        gt_mask_pedestrian=gt_mask[1].squeeze(),
        pr_mask_pedestrian=pr_mask[1].squeeze(),
    )