Argmax for semantic segmentation


I have a tensor that is the output of my network for multi-class semantic segmentation. It has shape [B, C, H, W].

Here B is the batch size, C is the number of classes, H is the image height and W is the image width. For each pixel of each image in the batch, I want a one-hot vector over the C classes.

How can I accomplish this with torch.argmax?

Thank you so much!

If your current target has shape [batch_size, c, h, w], try converting it with:

target = torch.argmax(target, 1)
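A minimal sketch of what that line does, assuming a random stand-in tensor (the sizes B, C, H, W below are hypothetical). argmax over dim 1 keeps only the index of the highest-scoring class for each pixel, so the class dimension is removed. The result is an int64 map of class indices with shape [B, H, W], which is the index format nn.CrossEntropyLoss accepts as a segmentation target.

```python
import torch

# Hypothetical sizes, for illustration only.
B, C, H, W = 2, 4, 8, 8

# Stand-in for the network output: one score per class per pixel.
logits = torch.randn(B, C, H, W)

# argmax over the class dimension (dim=1) returns the index of the
# highest-scoring class, so the C dimension disappears.
class_map = torch.argmax(logits, dim=1)

print(class_map.shape)  # torch.Size([2, 8, 8])
print(class_map.dtype)  # torch.int64
```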

See this related thread: Semantic segmentation loss function / shape of prediction and target

When I try this, the shape becomes [batch_size, h, w], but what I want is [batch_size, c, h, w] with c one-hot encoded. Any ideas?
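One way to get back to [B, C, H, W] is a minimal sketch like the following, assuming the tensor holds per-class scores. First take argmax over dim 1 to get the class index per pixel. Then apply torch.nn.functional.one_hot, which appends the new class dimension last ([B, H, W, C]), so a permute moves it back to position 1. An alternative that skips the permute writes 1.0 into a zero tensor with scatter_ at the argmax positions. The function names to_one_hot and to_one_hot_scatter and the tensor sizes are hypothetical.

```python
import torch
import torch.nn.functional as F


def to_one_hot(logits: torch.Tensor) -> torch.Tensor:
    """Turn [B, C, H, W] scores into a [B, C, H, W] one-hot mask (int64)."""
    num_classes = logits.shape[1]
    # Step 1: class index per pixel, shape [B, H, W].
    class_map = torch.argmax(logits, dim=1)
    # Step 2: F.one_hot appends the class dimension last: [B, H, W, C].
    one_hot = F.one_hot(class_map, num_classes=num_classes)
    # Step 3: move the class dimension back to position 1: [B, C, H, W].
    return one_hot.permute(0, 3, 1, 2)


def to_one_hot_scatter(logits: torch.Tensor) -> torch.Tensor:
    """Same mask via scatter_, keeping the input's dtype (e.g. float32)."""
    # keepdim=True keeps indices as [B, 1, H, W], the shape scatter_ needs.
    index = torch.argmax(logits, dim=1, keepdim=True)
    # Write 1.0 at the argmax position along dim 1; everything else stays 0.
    return torch.zeros_like(logits).scatter_(1, index, 1.0)


torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)  # hypothetical network output
mask = to_one_hot(logits)
print(mask.shape)  # torch.Size([2, 4, 8, 8])
```

F.one_hot returns an int64 tensor and permute returns a non-contiguous view, so call .float() or .contiguous() if a later operation requires it. The scatter_ variant produces the same mask directly in the input's dtype and layout.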