# How to convert argmax result to a one-hot matrix?

For example, I have a tensor of shape [N, C, H, W] = [1, 3, 2, 2].
Then I apply softmax and argmax to obtain the index:

```
# original tensor
tensor([[[[ 0.4008, -0.6662],
          [-0.4133,  1.3639]],

         [[-0.8354,  0.6317],
         [ 0.3240, -1.1438]],

         [[-0.3452,  1.2110],
          [ 0.6575,  0.9924]]]])

# after softmax over the channel dimension
tensor([[[[0.5666, 0.0893],
          [0.1664, 0.5646]],

         [[0.1646, 0.3270],
          [0.3479, 0.0460]],

         [[0.2687, 0.5837],
          [0.4856, 0.3894]]]])

# after argmax over the channel dimension
tensor([[[0, 2],
         [2, 0]]])
```

Then I want to use the index returned by argmax and convert it into a one-hot (binary) matrix:

```
tensor([[[[1, 0],
          [0, 1]],   # channel 0

         [[0, 0],
          [0, 0]],   # channel 1

         [[0, 1],
          [1, 0]]]]) # channel 2
```

How can I do this?

This should work:

```
import torch

x = torch.tensor([[[[ 0.4008, -0.6662],
                    [-0.4133,  1.3639]],

                   [[-0.8354,  0.6317],
                    [ 0.3240, -1.1438]],

                   [[-0.3452,  1.2110],
                    [ 0.6575,  0.9924]]]])

# index of the winning channel at each spatial location
pred = torch.argmax(x, dim=1)
print(pred)
> tensor([[[0, 2],
           [2, 0]]])

# write a 1. into the winning channel of an all-zero tensor;
# unsqueeze(1) gives the index tensor the same number of dims as x
out = torch.zeros_like(x).scatter_(1, pred.unsqueeze(1), 1.)
print(out)
> tensor([[[[1., 0.],
            [0., 1.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 1.],
            [1., 0.]]]])
```
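
For what it's worth, here is a minimal broadcast-based sketch that produces the same one-hot tensor (assuming the same `x` and `pred` as above): comparing the argmax indices against every channel id yields `True` exactly at the winning channel.

```
# compare each channel id against the winning index; the comparison
# broadcasts [1, 1, H, W] against [1, C, 1, 1] to give [1, C, H, W]
channel_ids = torch.arange(x.size(1)).view(1, -1, 1, 1)
out2 = (pred.unsqueeze(1) == channel_ids).float()
```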

Much appreciated! It works.

Hi Yun and @ptrblck!

As an alternative, you could use `torch.nn.functional.one_hot()`,
but would now have to `permute()` the dimensions of the result to get
them back in the order you want them:

```
>>> import torch
>>> torch.__version__
'1.9.0'
>>>
>>> t = torch.tensor ([
...     [[[ 0.4008, -0.6662],
...       [-0.4133,  1.3639]],
...
...      [[-0.8354,  0.6317],
...       [ 0.3240, -1.1438]],
...
...      [[-0.3452,  1.2110],
...       [ 0.6575,  0.9924]]]
... ])
>>>
>>> ind = t.argmax (dim = 1)
>>>
>>> ind
tensor([[[0, 2],
         [2, 0]]])
>>>
>>> torch.nn.functional.one_hot (ind).permute (0, 3, 1, 2)
tensor([[[[1, 0],
          [0, 1]],

         [[0, 0],
          [0, 0]],

         [[0, 1],
          [1, 0]]]])
```
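
Note that `one_hot()` appends the class dimension as the *last* dimension,
so the raw result has shape `[N, H, W, C]`, and the `permute (0, 3, 1, 2)`
moves it back to the `[N, C, H, W]` layout of the input. One caveat:
without the `num_classes` argument, `one_hot()` infers the number of
classes from the largest index present, so if the last channel never wins
the argmax, the output would have too few channels; passing `num_classes`
explicitly avoids this. A quick shape check, reusing `ind` from above:

```
>>> oh = torch.nn.functional.one_hot (ind, num_classes = 3)
>>> oh.shape                            # class dimension comes last
torch.Size([1, 2, 2, 3])
>>> oh.permute (0, 3, 1, 2).shape       # back to [N, C, H, W]
torch.Size([1, 3, 2, 2])
```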

This is arguably a bit more readable, as with “`one_hot()`” you say
what you mean (at the cost of the `permute()`).

(Note that, as @ptrblck illustrated, you do not need to use `softmax()`.
Because it doesn’t change the order of the values, leaving it out
doesn’t change the result of `argmax()`.)
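
For instance, a quick sanity check, again reusing `t` from above:

```
>>> probs = torch.softmax (t, dim = 1)
>>> torch.equal (probs.argmax (dim = 1), t.argmax (dim = 1))
True
```

Softmax is strictly increasing in each input, so it preserves the ordering
of the values along the channel dimension, and with it the argmax.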

Best.

K. Frank
