How to view prediction output as a segmented image in multi-class semantic segmentation?

I trained a UNet using nn.CrossEntropyLoss(). Since there were 32 labels in total and the batch size was 2, the prediction output has shape (2,32,388,388). How do I view the segmented image from this tensor? Images don't have 32 channels!

Take the argmax over the second dimension (the class dimension, dim=1) to get the most probable class per pixel.
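A minimal sketch of this, assuming a PyTorch output tensor with the shape from the question. Random logits stand in for the real model output here. Because softmax is monotonic, taking the argmax of the raw logits gives the same class as taking it after softmax, so no softmax is needed for this step.

```python
import torch

# Stand-in for the network output: random logits with the shape from the question
# (batch=2, classes=32, height=388, width=388). In practice this would be model(images).
logits = torch.randn(2, 32, 388, 388)

# Index of the highest-scoring class at every pixel. dim=1 is the class (channel)
# dimension, so it is reduced away: (2, 32, 388, 388) -> (2, 388, 388).
pred = logits.argmax(dim=1)

print(pred.shape)  # torch.Size([2, 388, 388])
print(pred.dtype)  # torch.int64
```

Each `pred[i]` is a single-channel label map with values from 0 to 31. It can be displayed directly, for example with matplotlib's `imshow(pred[0].cpu().numpy())` and a categorical colormap.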

Thanks, that solved the problem.

A related question: what is the best way to map the predicted class indices back to their original RGB colors?
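One common approach, sketched here under assumptions: store the colors as a lookup table of shape (num_classes, 3), where row k is the RGB color of class k, and index it with the argmax label map. The random `palette` and the `labels_to_rgb` helper are hypothetical placeholders. In practice the table should be the dataset's own class-to-color mapping, the same one used to build the label masks.

```python
import torch

NUM_CLASSES = 32

# Hypothetical palette: row k holds the RGB color of class k. Replace this with the
# dataset's real class-to-color table.
palette = torch.randint(0, 256, (NUM_CLASSES, 3), dtype=torch.uint8)

def labels_to_rgb(pred: torch.Tensor, palette: torch.Tensor) -> torch.Tensor:
    """Map an (N, H, W) tensor of class indices to an (N, H, W, 3) uint8 RGB image."""
    # Advanced indexing looks up one palette row per pixel in a single vectorized step.
    # The palette must be on the same device as pred.
    return palette.to(pred.device)[pred]

logits = torch.randn(2, NUM_CLASSES, 388, 388)  # stand-in for the model output
pred = logits.argmax(dim=1)                      # (2, 388, 388) class indices
rgb = labels_to_rgb(pred, palette)               # (2, 388, 388, 3) colors
```

The result is channels-last uint8, so `rgb[0].cpu().numpy()` can be passed directly to matplotlib's `imshow` or to `PIL.Image.fromarray`. Using a lookup table avoids a Python loop over classes and keeps the mapping in one place.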