How to visualize semantic segmentation maps

Hello!

I have the output of a semantic segmentation network with shape [21, 144, 256], and I would like to convert it to an image of shape [3, 144, 256] so that I can visualize it with one colour per class.

The approach for saving the images shown here works, but I would like to have the visualized image as a tensor so that I can concatenate it with the ground truth, or save a whole batch at once with torchvision.utils.save_image.

Thank you!

Assuming you are dealing with 21 classes and the current output contains the logits, you could get the predicted class indices via preds = torch.argmax(output, dim=0), where output would have the shape [21, 144, 256]. preds then has the shape [144, 256] and holds one class index in [0, 20] per pixel. You could then map each class index to a color (e.g. by indexing a [21, 3] color palette with preds) and move the color channel to the front to get an RGB image of shape [3, 144, 256].
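A minimal sketch of that mapping is below. It assumes the 21 classes follow the PASCAL VOC ordering and uses the VOC-style bit-based colormap as the palette. Any [num_classes, 3] uint8 tensor would work just as well. The helper names voc_colormap, labels_to_rgb and logits_to_rgb are hypothetical. The same functions handle a single [21, H, W] output or a batch of shape [N, 21, H, W], and they return float tensors in [0, 1]:

```python
import torch


def voc_colormap(num_entries: int = 256) -> torch.Tensor:
    """Assumed PASCAL VOC-style palette: entry i gets a colour built from the bits of i.

    Returns a [num_entries, 3] uint8 tensor. Using 256 entries also covers the
    "ignore" label 255 that VOC-style ground-truth masks often contain.
    """
    cmap = torch.zeros(num_entries, 3, dtype=torch.uint8)
    for i in range(num_entries):
        r = g = b = 0
        c = i
        for j in range(8):
            # Spread bits 0/1/2 of c into the high bits of R/G/B, then shift c.
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[i] = torch.tensor([r, g, b], dtype=torch.uint8)
    return cmap


def labels_to_rgb(labels: torch.Tensor, palette: torch.Tensor) -> torch.Tensor:
    """Integer labels [H, W] or [N, H, W] -> float RGB [3, H, W] or [N, 3, H, W] in [0, 1]."""
    rgb = palette.to(labels.device)[labels]  # advanced indexing: [..., H, W, 3], uint8
    rgb = rgb.movedim(-1, -3)                # channels first: [..., 3, H, W]
    return rgb.float() / 255.0               # float in [0, 1], as save_image expects


def logits_to_rgb(logits: torch.Tensor, palette: torch.Tensor) -> torch.Tensor:
    """Logits [C, H, W] or [N, C, H, W] -> RGB via argmax over the class dimension."""
    class_dim = 0 if logits.dim() == 3 else 1
    preds = torch.argmax(logits, dim=class_dim)  # [H, W] or [N, H, W], int64
    return labels_to_rgb(preds, palette)


if __name__ == "__main__":
    palette = voc_colormap()
    output = torch.randn(21, 144, 256)           # single output, as in the question
    print(logits_to_rgb(output, palette).shape)  # torch.Size([3, 144, 256])

    batch_logits = torch.randn(4, 21, 144, 256)
    targets = torch.randint(0, 21, (4, 144, 256))  # hypothetical ground-truth masks
    pred_rgb = logits_to_rgb(batch_logits, palette)
    gt_rgb = labels_to_rgb(targets, palette)
    side_by_side = torch.cat([pred_rgb, gt_rgb], dim=3)  # along width: [4, 3, 144, 512]
    print(side_by_side.shape)
```

Because the result is an ordinary float tensor, it can be concatenated with the colorized ground truth (along width as above, or along the batch dimension with torch.cat([pred_rgb, gt_rgb], dim=0)) and then passed to torchvision.utils.save_image to write the whole batch as one grid. Indexing the palette with the label tensor avoids a Python loop over classes, and the conversion stays on the GPU if the logits are there.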