I have a tensor representing multi-class semantic segmentation that is the output of my network. It has shape [B, C, H, W],
where B is the batch size, C is the number of classes, H is the image height and W is the image width. For each pixel of each image in the batch, I want a one-hot vector over the C classes.
How can I accomplish this with torch.argmax?
Thank you so much!
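
For reference, here is a minimal sketch of one way this could be done, assuming the class dimension is dim=1 and that a hard one-hot encoding of the highest-scoring class is what is wanted. The tensor sizes and the names `logits`, `labels` and `one_hot` are made up for illustration. It takes `torch.argmax` over the class dimension and then expands the indices with `torch.nn.functional.one_hot`:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for this example.
B, C, H, W = 2, 3, 4, 5

# Stand-in for the network output, shape [B, C, H, W].
logits = torch.randn(B, C, H, W)

# Index of the highest-scoring class at each pixel, shape [B, H, W].
labels = torch.argmax(logits, dim=1)

# F.one_hot appends the class dimension last: [B, H, W, C] (dtype int64).
one_hot = F.one_hot(labels, num_classes=C)

# Move the class dimension back to position 1: [B, C, H, W].
one_hot = one_hot.permute(0, 3, 1, 2)

print(one_hot.shape)
```

Note that `permute` returns a non-contiguous view, so calling `.contiguous()` may be needed before `.view()`, and `.float()` can be added if a floating-point mask is required. An equivalent alternative would be `torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1.0)`.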