Model not predicting any values for certain mask in multi-class segmentation

I have trained a multi-class segmentation model on some data. My targets were in the shape (4, 320, 320) / (4, 512, 512), which denotes 4 classes, with each (1, 512, 512) slice being a binary segmentation mask for one class (that's how I'm pretty sure it works; if not, please correct me). I got my predictions back and everything looks like it's working (the predictions aren't too good, but that doesn't matter right now). The predictions are also in the shape (4, 512, 512), and I can extract each binary mask by slicing, e.g. img[1, :, :]. I'm wondering how I can combine all of these masks into a final mask for the image.
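
To be concrete, this is roughly how my prediction tensor is laid out and how I'm slicing it (a minimal sketch with a made-up tensor, not my actual model output):

import torch

pred = torch.rand(4, 512, 512)   # [nb_classes, height, width], one channel per class
mask_class1 = pred[1, :, :]      # mask for class 1, shape [512, 512]
print(mask_class1.shape)         # torch.Size([512, 512])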

Here is an example of the predictions:

and there are 4 of these, one for each class, that I would like to combine into one image that looks something like this:

Is there a way to achieve this, and if so, what can I do?

To get the predicted class index for each pixel position you could use:

output # [batch_size, nb_classes, height, width]
pred = torch.argmax(output, dim=1) # [batch_size, height, width]

and map each class index to a color code or directly plot it via e.g. plt.imshow.
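
For example, assuming output holds the raw model predictions for a batch, a quick way to visualize the combined mask could look like this (just a sketch with dummy logits; adjust the colormap to taste):

import torch
import matplotlib.pyplot as plt

output = torch.randn(1, 4, 512, 512)   # dummy logits: [batch_size, nb_classes, height, width]
pred = torch.argmax(output, dim=1)     # class index per pixel: [batch_size, height, width]

plt.imshow(pred[0].cpu().numpy(), cmap='tab10', vmin=0, vmax=9)  # each class index gets its own color
plt.colorbar()
plt.show()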

I understand the argmax part, but how would I map the pixel values to RGB values?

You could either pick a colormap in plt.imshow that works for you, or map each class index to an RGB value before visualizing the output.
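
A minimal sketch of the second approach, assuming 4 classes and a hand-picked palette (the colors here are arbitrary):

import torch

# one RGB triplet per class index
palette = torch.tensor([
    [0, 0, 0],       # class 0 -> black
    [255, 0, 0],     # class 1 -> red
    [0, 255, 0],     # class 2 -> green
    [0, 0, 255],     # class 3 -> blue
], dtype=torch.uint8)

pred = torch.randint(0, 4, (1, 512, 512))   # dummy prediction: [batch_size, height, width]
rgb = palette[pred]                         # index into the palette: [batch_size, height, width, 3]
print(rgb.shape)                            # torch.Size([1, 512, 512, 3])

You can then pass rgb[0].numpy() directly to plt.imshow, since it accepts (height, width, 3) uint8 arrays.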