Problem with visualizing multi-class predictions

I ran U-Net (with softmax) on the CamVid data to predict multi-class segmentation. Everything seems to run well, but I am having trouble visualizing the predictions: they look flat, as shown below.

Full code can be accessed here.

The output looks like this:


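Helper function for data visualization:

import matplotlib.pyplot as plt

def visualize(**images):
    """Plot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)  # without this call nothing is drawn
    plt.show()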

Applying visualize on predictions:

image, gt_mask = test_dataset[n]
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)

gt_mask_0 = gt_mask[..., 0].squeeze()
gt_mask_1 = gt_mask[..., 1].squeeze()

pr_mask = model.predict(x_tensor)
pr_mask_0 = pr_mask[..., 0].squeeze().cpu().numpy().round()
pr_mask_1 = pr_mask[..., 1].squeeze().cpu().numpy().round()

sky_mask = pr_mask_0


array([[[1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.],
        [1., 1., 1., ..., 1., 1., 1.]],


       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)


tensor([[[[0.0089, 0.0096, 0.0101,  ..., 0.0044, 0.0041, 0.0038],
          [0.0094, 0.0098, 0.0099,  ..., 0.0057, 0.0056, 0.0053],
          [0.0097, 0.0097, 0.0094,  ..., 0.0071, 0.0072, 0.0072],
          [0.0126, 0.0129, 0.0128,  ..., 0.0063, 0.0070, 0.0076],
          [0.0113, 0.0116, 0.0115,  ..., 0.0062, 0.0073, 0.0086],
          [0.0101, 0.0103, 0.0102,  ..., 0.0060, 0.0077, 0.0097]],

         [[0.0354, 0.0348, 0.0332,  ..., 0.0047, 0.0061, 0.0078],
          [0.0471, 0.0452, 0.0421,  ..., 0.0052, 0.0068, 0.0086],
          [0.0610, 0.0571, 0.0515,  ..., 0.0054, 0.0072, 0.0094],
          [0.0111, 0.0161, 0.0225,  ..., 0.0364, 0.0368, 0.0369],
          [0.0078, 0.0115, 0.0165,  ..., 0.0303, 0.0342, 0.0382],
          [0.0054, 0.0081, 0.0120,  ..., 0.0251, 0.0316, 0.0394]],

         [[0.0090, 0.0070, 0.0053,  ..., 0.0049, 0.0055, 0.0061],
          [0.0097, 0.0084, 0.0070,  ..., 0.0065, 0.0074, 0.0082],
          [0.0102, 0.0098, 0.0090,  ..., 0.0084, 0.0096, 0.0106],
          [0.0245, 0.0227, 0.0202,  ..., 0.0142, 0.0136, 0.0129],
          [0.0211, 0.0191, 0.0169,  ..., 0.0131, 0.0131, 0.0128],
          [0.0180, 0.0160, 0.0140,  ..., 0.0120, 0.0124, 0.0127]],


         [[0.0259, 0.0252, 0.0239,  ..., 0.0033, 0.0048, 0.0068],
          [0.0253, 0.0244, 0.0228,  ..., 0.0043, 0.0060, 0.0082],
          [0.0242, 0.0230, 0.0211,  ..., 0.0052, 0.0072, 0.0097],
          [0.0636, 0.0476, 0.0343,  ..., 0.0413, 0.0394, 0.0371],
          [0.0640, 0.0479, 0.0348,  ..., 0.0459, 0.0416, 0.0373],
          [0.0639, 0.0479, 0.0352,  ..., 0.0508, 0.0437, 0.0372]],

         [[0.7196, 0.7019, 0.6668,  ..., 0.6152, 0.5556, 0.4922],
          [0.6529, 0.6269, 0.5832,  ..., 0.5199, 0.4687, 0.4144],
          [0.5768, 0.5427, 0.4922,  ..., 0.4186, 0.3824, 0.3403],
          [0.3214, 0.2712, 0.2201,  ..., 0.3356, 0.3038, 0.2719],
          [0.3338, 0.2885, 0.2423,  ..., 0.3139, 0.2817, 0.2503],
          [0.3442, 0.3053, 0.2652,  ..., 0.2915, 0.2599, 0.2290]],

         [[0.0182, 0.0217, 0.0253,  ..., 0.0382, 0.0371, 0.0354],
          [0.0259, 0.0309, 0.0358,  ..., 0.0667, 0.0543, 0.0434],
          [0.0359, 0.0427, 0.0489,  ..., 0.1109, 0.0769, 0.0520],
          [0.0222, 0.0215, 0.0199,  ..., 0.1235, 0.1231, 0.1213],
          [0.0250, 0.0241, 0.0225,  ..., 0.1325, 0.1389, 0.1441],
          [0.0278, 0.0268, 0.0253,  ..., 0.1412, 0.1559, 0.1701]]]],
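The tensor above has the layout [batch, classes, height, width] (channels-first), with one softmax probability per class at each pixel. A minimal sketch of turning such an output into something plottable, assuming it has already been converted to NumPy (e.g. with `pr_mask.squeeze(0).cpu().numpy()`): take the argmax over the class axis to get one label map, or compare that label map against each class index to get one binary mask per class. The random array is only a stand-in for the model output, and the sizes are arbitrary. It also illustrates why `.round()` can make predictions look empty: with 12 classes, the winning class at a pixel often has a probability below 0.5.

```python
import numpy as np

# Stand-in for a softmax output of shape [batch, classes, height, width].
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 12, 4, 5))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax over the class axis

# Drop the batch axis -> (classes, height, width).
probs_chw = probs[0]

# One label map: index of the most probable class at each pixel -> (height, width).
label_map = probs_chw.argmax(axis=0)

# One binary mask per class, built from the class axis (axis 0 here), not the last axis.
class_masks = [(label_map == c).astype(np.float32) for c in range(probs_chw.shape[0])]

# Rounding each probability keeps a class only where it exceeds 0.5,
# so pixels whose winning class has a lower probability end up with no class at all.
unassigned = int((probs_chw.round().sum(axis=0) == 0).sum())
print(unassigned, "of", label_map.size, "pixels have no class after rounding")
```

`label_map` can be passed straight to `plt.imshow` (optionally with a qualitative colormap), and each entry of `class_masks` is a 2D mask that `visualize` can show on its own.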

Could you post the shapes of the masks you are passing to visualize?

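# check shapes
images, masks = next(iter(train_loader))
print(images.shape)  # [batch_size, channels, height, width]
print(masks.shape)   # [batch_size, no_of_classes, height, width]

images, masks = next(iter(test_loader))
print(images.shape)
print(masks.shape)

Output: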

torch.Size([8, 3, 320, 320])
torch.Size([8, 12, 320, 320])
torch.Size([8, 3, 384, 480])
torch.Size([8, 12, 384, 480])
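These shapes show that the masks are channels-first, so `mask[..., 0]` selects pixel column 0 of the width axis rather than class 0. A small NumPy sketch of the difference, using an all-zero array with the test-mask shape printed above (the values are irrelevant, only the shapes matter):

```python
import numpy as np

# Same layout as the test masks above: [batch_size, no_of_classes, height, width].
masks = np.zeros((8, 12, 384, 480), dtype=np.float32)

# Ellipsis indexing on the last axis picks one pixel column, not one class.
print(masks[..., 0].shape)  # (8, 12, 384)

# Indexing the class axis picks class 0 for every image in the batch.
print(masks[:, 0].shape)  # (8, 384, 480)

# Class 0 of the first image: a 2D array that plt.imshow can display.
print(masks[0, 0].shape)  # (384, 480)
```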

I figured out the solution. This worked for me, though I had to display the masks in separate figures.


I am facing the same problem. How did you merge all the masks, then? This is what I got:

I could merge all the masks into one image, but only as follows:

But when I tried to do the same on the prediction images, it did not work.

I used the following code to do so:

test_dataset_vis = Dataset(x_test_dir, y_test_dir, classes=['sky', 'building', 'pole', 'road'])  # for multi-class

image, mask = test_dataset_vis[10]  # get some sample

visualize(image=image, ground_truth_many_mask=mask[..., 0:3].squeeze())
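Here `mask[..., 0:3]` works because this visualization dataset returns channels-last masks: the slice has shape (H, W, 3), which `plt.imshow` renders as RGB, so the first class shows up red, the second green, the third blue, and any further classes are dropped. A sketch of merging all classes instead, assuming a channels-last one-hot mask: take the argmax over the last axis and look each class index up in a colour palette. `colorize_mask` and the palette colours are hypothetical, arbitrary choices for illustration.

```python
import numpy as np

def colorize_mask(mask_hwc, palette):
    """Merge a channels-last one-hot mask (H, W, C) into one RGB image (H, W, 3).

    `palette` is a (C, 3) uint8 array with one colour per class.
    Pixels that belong to no class are drawn black.
    """
    label_map = mask_hwc.argmax(axis=-1)  # (H, W): class index per pixel
    rgb = palette[label_map]              # (H, W, 3): fancy indexing returns a copy
    rgb[mask_hwc.max(axis=-1) == 0] = 0   # pixels with no active class -> black
    return rgb

# Toy 2x2 one-hot mask with 4 classes (e.g. sky, building, pole, road).
mask = np.zeros((2, 2, 4), dtype=np.float32)
mask[0, 0, 0] = 1  # sky
mask[0, 1, 1] = 1  # building
mask[1, 0, 3] = 1  # road; pixel (1, 1) has no class

palette = np.array([[128, 128, 128],  # sky
                    [128, 0, 0],      # building
                    [192, 192, 128],  # pole
                    [128, 64, 128]],  # road
                   dtype=np.uint8)

rgb = colorize_mask(mask, palette)
```

The result can be shown with `plt.imshow(rgb)` or passed to `visualize`. For a channels-first prediction, the same idea applies with `argmax(axis=0)` on the (C, H, W) array, or after transposing it to (H, W, C).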

Could you share the code you've used to merge the class indices from the ground truth labels? I would assume the same code would also work on your predictions. Could you also explain which issue you are currently facing while trying to visualize the predictions?

Thank you for the quick response. This is what I used:

When I tried to use the same code on the prediction data, I got this:

Your mask tensors are in channels-first format, while you are trying to slice them in the last dimension, which creates these small outputs.
I don't know if you are dealing with 4 classes or if the channels represent RGBA, but in any case you need to permute the tensors so that they are represented in channels-last format.
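In PyTorch, `tensor.permute(1, 2, 0)` reorders a (C, H, W) tensor to (H, W, C); NumPy's `transpose(1, 2, 0)` does the same for arrays. A sketch, assuming a single prediction of shape [1, C, H, W] that has already been converted with `.cpu().numpy()` (the random array and its sizes are stand-ins). Note that `plt.imshow` only accepts a 2D array or an (H, W, 3) / (H, W, 4) array, so a mask with more than four class channels still cannot be displayed directly after the permutation; it has to be reduced to one label map or one mask per class first.

```python
import numpy as np

# Stand-in for model.predict(x).cpu().numpy(): [batch=1, classes=4, height=3, width=5].
pr_mask = np.random.default_rng(0).random((1, 4, 3, 5)).astype(np.float32)

chw = pr_mask.squeeze(0)      # (4, 3, 5): channels-first
hwc = chw.transpose(1, 2, 0)  # (3, 5, 4): channels-last, like tensor.permute(1, 2, 0)

# After the permutation, [..., c] selects class c.
print(hwc.shape)  # (3, 5, 4)
```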

Thank you. Based on your comment, the results improved. I made the conversion (CHW → HWC), but I still can't get all the classes onto the predicted mask. I'd appreciate it if you have a solution for that.

And I still need to read up on "permuting the tensors".

I think the channels represent RGBA.
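One hedged way to tell the two cases apart: one-hot class masks contain only 0s and 1s and have at most one active channel per pixel, whereas RGBA images usually do not. Note also that `plt.imshow` interprets any (H, W, 4) float array as RGBA, so with four class channels the fourth class would be rendered as transparency rather than as a colour. A sketch of that heuristic; `looks_like_one_hot` is a hypothetical helper, not a library function.

```python
import numpy as np

def looks_like_one_hot(mask_hwc, atol=1e-6):
    """Heuristic: True if every value is 0 or 1 and each pixel has at most one active channel."""
    binary = np.all(np.isclose(mask_hwc, 0, atol=atol) | np.isclose(mask_hwc, 1, atol=atol))
    at_most_one = np.all(mask_hwc.sum(axis=-1) <= 1 + atol)
    return bool(binary and at_most_one)

one_hot = np.zeros((2, 2, 4), dtype=np.float32)
one_hot[0, 0, 1] = 1
one_hot[1, 1, 3] = 1

rgba = np.full((2, 2, 4), 0.5, dtype=np.float32)  # mid-grey, half-transparent image

print(looks_like_one_hot(one_hot))  # True
print(looks_like_one_hot(rgba))     # False
```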