import torch
a = torch.randn(2, 3, 2, 2, requires_grad=True)  # i.e. (batch, chnl, h, w)
amax = torch.argmax(a, dim=1)
print(a)
print(amax)
My aim is to convert the result to binary values, so that the value at the max index (along dim=1) is 1 and all others are 0.
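One possible approach, sketched here as an assumption rather than a confirmed answer, is to one-hot encode the argmax indices with `torch.nn.functional.one_hot` and then move the new class dimension (which `one_hot` appends last) back to position 1. Note that the result is an integer (`int64`) tensor and carries no gradient, since `argmax` is not differentiable.

```python
import torch
import torch.nn.functional as F

# Same setup as in the question: (batch, chnl, h, w)
a = torch.randn(2, 3, 2, 2, requires_grad=True)
amax = torch.argmax(a, dim=1)  # shape (2, 2, 2), int64 indices into dim=1

# one_hot adds the class dimension at the end: (batch, h, w, chnl)
onehot = F.one_hot(amax, num_classes=a.size(1))

# Move the channel dimension back to dim=1: (batch, chnl, h, w)
binary = onehot.permute(0, 3, 1, 2)

print(binary.shape)  # torch.Size([2, 3, 2, 2])
```

If a float mask is needed (e.g. to multiply with `a`), `binary.to(a.dtype)` converts it.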
For example, based on the above, I want to get this:
How do I achieve that?
I feel I am missing a simple way to do it, but I just couldn’t figure it out.
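A minimal sketch of another way, assuming the mask should keep the same shape and dtype as `a`: take `argmax` with `keepdim=True` and scatter ones into a zero tensor along dim=1. As with any argmax-based mask, this is not differentiable with respect to `a`.

```python
import torch

a = torch.randn(2, 3, 2, 2, requires_grad=True)  # (batch, chnl, h, w)

# keepdim=True keeps dim=1 as size 1, so the index has shape (2, 1, 2, 2)
amax = torch.argmax(a, dim=1, keepdim=True)

# Start from zeros with a's shape and dtype (no grad tracking),
# then write 1.0 at the argmax position along dim=1
binary = torch.zeros_like(a).scatter_(1, amax, 1.0)

print(binary)
```

A comparison such as `a == a.amax(dim=1, keepdim=True)` gives a similar boolean mask, but it would mark every position that ties for the maximum rather than exactly one.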