Convert output to binary values

a = torch.randn(2, 3, 2, 2, requires_grad=True)  # i.e. (batch, channel, h, w)
amax = torch.argmax(a, dim=1)

print(a): [image: printed tensor a]
print(amax): [image: printed tensor amax]

My aim is to convert the result to binary values, so that the value at the max index (along dim=1) is 1 and all others are 0.
For example, based on the above, I want to get this: [image: expected binary output]

How do I achieve that?
I feel I am missing a simple way to do it, but I just couldn't figure it out.

Hi @shadowhy ,

You can try this:

 # True where a value equals the maximum along dim=1; append .float() to get 0/1 values
 output = (a == torch.max(a, 1, keepdim=True)[0])

Let me know if it’s not what you’re looking for.
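
For anyone comparing options, here is a minimal sketch of the comparison approach next to an alternative that one-hot encodes the argmax indices. The alternative is not from the reply above. It assumes the same (batch, channel, h, w) layout as the question, and the variable names (mask, idx, one_hot) are only illustrative. One difference to be aware of: if two channels tie for the maximum, the comparison marks both, while argmax picks a single index.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(2, 3, 2, 2)  # (batch, channel, h, w), as in the question

# Comparison approach: 1.0 wherever a value equals its maximum along dim=1.
mask = (a == a.max(dim=1, keepdim=True).values).float()

# Alternative (an assumption, not the reply's method): one-hot encode the argmax.
idx = a.argmax(dim=1)                             # shape (2, 2, 2)
one_hot = F.one_hot(idx, num_classes=a.size(1))   # shape (2, 2, 2, 3), class axis last
one_hot = one_hot.permute(0, 3, 1, 2).float()     # move class axis back to dim=1

print(one_hot.shape)
print(one_hot.sum(dim=1))  # each spatial position has exactly one 1 across channels
```

Neither result carries gradients back to a, since both the comparison and argmax are non-differentiable, so these tensors are suited to evaluation or masking rather than direct use in a loss.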