I’ve searched a lot without finding anything useful, so I’m posting here.

I have a tensor of the form:
(0 ,.,.) =
1.0858 -3.2292 -0.6477

(1 ,.,.) =
3.3863 -2.3600 -2.6776

(2 ,.,.) =
-2.3978 -2.0387 2.0728

I want to transform it so that the maximum value in each row becomes 1 and every other value becomes 0.
Any help would be very much appreciated.

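If temp is your tensor, this line does what you want (it reduces over the last dimension, i.e. along each row, which also works for the (3, 1, 3) shape printed above):
import torch
mask = (temp == temp.max(dim=-1, keepdim=True)[0]).to(dtype=torch.int32)  # 1 at each row's max, 0 elsewhere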
And if you also want to keep the max values themselves (with zeros everywhere else):
temp = torch.mul(mask, temp)
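A quick check of the mask-and-multiply approach on the tensor from the question. This is a sketch that assumes the tensor has shape (3, 1, 3), which is what the "(i ,.,.)" printout suggests, so it reduces over the last dimension. The one_hot variant at the end is an alternative that was not proposed in the thread: the comparison mask puts a 1 at every tied maximum, while argmax picks a single index per row.

```python
import torch
import torch.nn.functional as F

# The tensor from the question; the "(i ,.,.)" printout suggests shape (3, 1, 3).
temp = torch.tensor([[[1.0858, -3.2292, -0.6477]],
                     [[3.3863, -2.3600, -2.6776]],
                     [[-2.3978, -2.0387, 2.0728]]])

# Compare every element with the maximum of its row (last dim).
# keepdim=True keeps shape (3, 1, 1) so the comparison broadcasts over the row.
row_max = temp.max(dim=-1, keepdim=True)[0]
mask = (temp == row_max).to(dtype=torch.int32)
print(mask.squeeze(1).tolist())  # [[1, 0, 0], [1, 0, 0], [0, 0, 1]]

# Keep the max values themselves, zeros elsewhere (int32 * float32 promotes to float32).
max_only = torch.mul(mask, temp)

# Alternative (not from the thread): exactly one 1 per row, even if the maximum is tied.
# argmax returns int64 indices of shape (3, 1); one_hot appends a class dim -> (3, 1, 3).
one_hot_mask = F.one_hot(temp.argmax(dim=-1), num_classes=temp.size(-1))
```

With tied maxima, e.g. a row like [2.0, 2.0, 1.0], the comparison mask has two 1s in that row and the one_hot mask has one.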

Hi,
I’m not sure how to extend this to the n-dimensional case. @richard

Assume I have a tensor where the first two dims are batch and channel, and the last three correspond to xyz space:

A = torch.randn(b,c,32,32,32)

What I would like to do is binarize along the x dimension (dim=2) for every batch and channel, i.e. for every yz location I want to set the maximum value along the x-axis to 1 and the rest to 0. Is there a way to do this?
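The same comparison trick generalizes: reduce over the axis you care about and keep that dimension so the result broadcasts back over it. A minimal sketch for the 5-D case, with b and c set to small hypothetical values for illustration: max over dim=2 with keepdim=True gives one maximum per (batch, channel, y, z) location. The scatter_ version is an alternative (not from the thread) that guarantees exactly one 1 along x even when values tie.

```python
import torch

torch.manual_seed(0)
b, c = 2, 3  # hypothetical batch/channel sizes, for illustration only
A = torch.randn(b, c, 32, 32, 32)

# Max along x (dim=2); keepdim=True gives shape (b, c, 1, 32, 32),
# which broadcasts back over the x axis in the comparison.
x_max = A.max(dim=2, keepdim=True).values
mask = (A == x_max).to(torch.int32)  # 1 at the x-maximum of each (b, c, y, z), else 0

# Tie-safe alternative (not from the thread): write a single 1 at the argmax along x.
idx = A.argmax(dim=2, keepdim=True)  # int64, shape (b, c, 1, 32, 32)
mask_scatter = torch.zeros_like(A, dtype=torch.int32).scatter_(2, idx, 1)

# Keep the max values instead of 1s, zeros elsewhere.
max_only = A * mask
```

For a different axis, change the dim argument (dim=3 for y, dim=4 for z); the pattern is the same for any number of dimensions.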