For example, I have a tensor of shape [N, C, H, W] = [1, 3, 2, 2].

Then I apply softmax and argmax over the channel dimension to obtain the index:

```
# original tensor
tensor([[[[ 0.4008, -0.6662],
          [-0.4133,  1.3639]],

         [[-0.8354,  0.6317],
          [ 0.3240, -1.1438]],

         [[-0.3452,  1.2110],
          [ 0.6575,  0.9924]]]])
# after softmax over the channel dimension
tensor([[[[0.5666, 0.0893],
          [0.1664, 0.5646]],

         [[0.1646, 0.3270],
          [0.3479, 0.0460]],

         [[0.2687, 0.5837],
          [0.4856, 0.3894]]]])
# after argmax over the channel dimension
tensor([[[0, 2],
         [2, 0]]])
```
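For reference, the steps above can be reproduced roughly like this. It is a minimal sketch that assumes softmax and argmax were both taken over `dim=1` (the channel axis), which matches the printed values: the three probabilities at each pixel sum to 1. The variable names `x`, `probs` and `idx` are only illustrative.

```python
import torch

# Original [N, C, H, W] = [1, 3, 2, 2] tensor from the example
x = torch.tensor([[[[ 0.4008, -0.6662],
                    [-0.4133,  1.3639]],
                   [[-0.8354,  0.6317],
                    [ 0.3240, -1.1438]],
                   [[-0.3452,  1.2110],
                    [ 0.6575,  0.9924]]]])

# Softmax over the channel dimension, so the 3 values at each pixel sum to 1
probs = torch.softmax(x, dim=1)

# Argmax over the channel dimension; dim=1 is dropped, giving shape [N, H, W]
idx = torch.argmax(probs, dim=1)
print(idx.shape)     # torch.Size([1, 2, 2])
print(idx.tolist())  # [[[0, 2], [2, 0]]]
```

Note that `idx` has shape [1, 2, 2], not [1, 1, 2, 2], because `argmax` removes the reduced dimension unless `keepdim=True` is passed.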

Then I want to use the indices returned by argmax to build a binary (one-hot) tensor with the original shape [1, 3, 2, 2], where channel c is 1 at every pixel whose index is c:

```
tensor([[[[1, 0],
          [0, 1]],   # channel 0

         [[0, 0],
          [0, 0]],   # channel 1

         [[0, 1],
          [1, 0]]]]) # channel 2
```

How can I do this?
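One possible approach, sketched under the assumption that the indices are an int64 tensor of shape [N, H, W] (which is what `torch.argmax(..., dim=1)` returns): `torch.nn.functional.one_hot` adds the class dimension last, so it has to be permuted back to position 1. An alternative with `scatter_` writes the 1s directly into an [N, C, H, W] zero tensor. The names `idx`, `num_classes`, `binary` and `binary_alt` are illustrative only.

```python
import torch
import torch.nn.functional as F

# Channel indices from argmax, shape [N, H, W] = [1, 2, 2]
idx = torch.tensor([[[0, 2],
                     [2, 0]]])
num_classes = 3  # C, the number of channels in the original tensor

# one_hot appends the class dimension last: [N, H, W] -> [N, H, W, C]
one_hot = F.one_hot(idx, num_classes=num_classes)

# Move the class dimension back to position 1: [N, H, W, C] -> [N, C, H, W]
binary = one_hot.permute(0, 3, 1, 2)

# Alternative: start from zeros of shape [N, C, H, W] and put a 1 at
# channel idx[n, h, w] for every pixel (n, h, w)
binary_alt = torch.zeros(idx.shape[0], num_classes, *idx.shape[1:], dtype=torch.long)
binary_alt.scatter_(1, idx.unsqueeze(1), 1)

print(binary.tolist())
# [[[[1, 0], [0, 1]], [[0, 0], [0, 0]], [[0, 1], [1, 0]]]]
```

Both produce the same values. The `permute` result is a non-contiguous view, so call `.contiguous()` on it if a later operation needs that; the `scatter_` version is contiguous already. Either result can be cast with `.float()` or `.bool()` if a different dtype is needed.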