How to get the max value index along one dimension

Hi, I'm doing image segmentation, and at the end I get a segmentation map.

If I have 4 classes and my output segmentation map is 128 x 128, the result is a 4x128x128 tensor.

Now I need to find, for each pixel, which class has the largest probability so that I can visualize the segmentation.

So I have a 4x128x128 tensor, and I need to know which index along the first dimension has the maximum value.

For example, with class labels 0, 1, 2, 3, if a pixel's segmentation scores are a, b, c, d and d is the largest, then the class label for that pixel is 3. How can I get the index of the maximum value?

Could someone help me?
Thank you!


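Calling torch.max along dimension 0 would work: with a dim argument it returns both the maximum values and their indices, and those indices are the class labels.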
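A minimal sketch of this approach, assuming the network output is a 4 x 128 x 128 tensor named `output` (the name and the random values are placeholders for the real scores). It also reproduces the single-pixel example from the question, where the fourth score is the largest:

```python
import torch

# Placeholder scores: 4 classes over a 128 x 128 map.
# Random values stand in for the real network output.
num_classes, height, width = 4, 128, 128
output = torch.randn(num_classes, height, width)

# With a dim argument, torch.max returns (values, indices).
# The indices along dim 0 are the class labels 0..3 for each pixel.
max_scores, class_map = torch.max(output, dim=0)
print(class_map.shape)  # torch.Size([128, 128])

# A single pixel with scores a, b, c, d where d is the largest.
pixel = torch.tensor([0.1, 0.2, 0.3, 0.9])
label = torch.max(pixel, dim=0).indices
print(label.item())  # 3
```

`class_map` holds one integer label per pixel, so it can be passed directly to an image-plotting function for visualization.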


More details about torch.max are in the PyTorch documentation.
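Segmentation networks usually return a batch dimension first, so the class scores sit on dimension 1 rather than 0. The sketch below assumes a hypothetical N x 4 x 128 x 128 output with N = 2. It also shows the `values`/`indices` fields of the torch.max result and `torch.argmax`, which returns only the indices:

```python
import torch

# Hypothetical batch of 2 score maps: N x C x H x W = 2 x 4 x 128 x 128.
batch_output = torch.randn(2, 4, 128, 128)

# Reduce over the class dimension (dim=1) and read the named fields
# instead of unpacking a tuple. Each result has shape 2 x 128 x 128.
result = torch.max(batch_output, dim=1)
scores, class_maps = result.values, result.indices

# torch.argmax gives only the index tensor, with the same shape.
argmax_maps = torch.argmax(batch_output, dim=1)
print(class_maps.shape, argmax_maps.shape)
```

Either index tensor can serve as the per-image class-label map. Use `torch.argmax` when the maximum scores themselves are not needed.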
