How to convert an argmax result to a one-hot matrix?

For example, I have a tensor of shape [N, C, H, W] = [1, 3, 2, 2].
Then I apply softmax and argmax to obtain the indices:

# original tensor
tensor([[[[ 0.4008, -0.6662],
          [-0.4133,  1.3639]],

         [[-0.8354,  0.6317],
          [ 0.3240, -1.1438]],

         [[-0.3452,  1.2110],
          [ 0.6575,  0.9924]]]])

# after softmax
tensor([[[[0.5666, 0.0893],
          [0.1664, 0.5646]],

         [[0.1646, 0.3270],
          [0.3479, 0.0460]],

         [[0.2687, 0.5837],
          [0.4856, 0.3894]]]])

# after argmax on channel dimension
tensor([[[0, 2],
         [2, 0]]])
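A minimal sketch of the two steps described above, assuming the same input values. Softmax is taken over the channel dimension (dim=1), and argmax over that same dimension then removes it, leaving shape [N, H, W]:

```python
import torch

# Same [N, C, H, W] = [1, 3, 2, 2] input as in the question
x = torch.tensor([[[[ 0.4008, -0.6662],
                    [-0.4133,  1.3639]],
                   [[-0.8354,  0.6317],
                    [ 0.3240, -1.1438]],
                   [[-0.3452,  1.2110],
                    [ 0.6575,  0.9924]]]])

probs = torch.softmax(x, dim=1)    # normalize across the 3 channels at each pixel
pred = torch.argmax(probs, dim=1)  # index of the winning channel per pixel

print(probs.shape, pred.shape)     # torch.Size([1, 3, 2, 2]) torch.Size([1, 2, 2])
print(pred)                        # tensor([[[0, 2], [2, 0]]])
```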

Then I want to use the indices returned by argmax and convert them into a binary (one-hot) tensor:

tensor([[[[1, 0],
          [0, 1]],   # channel 0

         [[0, 0],
          [0, 0]],   # channel 1

         [[0, 1],
          [1, 0]]]]) # channel 2

How can I do this?

This should work:

import torch

x = torch.tensor([[[[ 0.4008, -0.6662],
                    [-0.4133,  1.3639]],

                   [[-0.8354,  0.6317],
                    [ 0.3240, -1.1438]],

                   [[-0.3452,  1.2110],
                    [ 0.6575,  0.9924]]]])

pred = torch.argmax(x, dim=1)
print(pred)
> tensor([[[0, 2],
           [2, 0]]])

out = torch.zeros_like(x).scatter_(1, pred.unsqueeze(1), 1.)
print(out)
> tensor([[[[1., 0.],
            [0., 1.]],

           [[0., 0.],
            [0., 0.]],

           [[0., 1.],
            [1., 0.]]]])
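To spell out what scatter_ is doing here: scatter_(dim, index, value) writes value into the tensor at the positions that index selects along dim. With dim=1 and an index of shape [N, 1, H, W] (hence the unsqueeze(1)), it sets out[n, pred[n, h, w], h, w] = 1 for every pixel. The explicit loop below is only an illustration of that rule on an arbitrary random input, not something you would use in practice; it builds the same tensor element by element and compares the two:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 3, 5)   # arbitrary [N, C, H, W] input for illustration
pred = x.argmax(dim=1)        # shape [N, H, W]

# Vectorized version from the answer above
out = torch.zeros_like(x).scatter_(1, pred.unsqueeze(1), 1.)

# Explicit loop spelling out what scatter_ does along dim=1
ref = torch.zeros_like(x)
N, C, H, W = x.shape
for n in range(N):
    for h in range(H):
        for w in range(W):
            ref[n, pred[n, h, w], h, w] = 1.

print(torch.equal(out, ref))  # True
print(out.sum(dim=1))         # every pixel has exactly one 1 across the channels
```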

Much appreciated! It works.

Hi Yun and @ptrblck!

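As an alternative, you could use torch.nn.functional.one_hot(),
but you would then have to permute() the dimensions of the result to get
them back in the order you want. one_hot() appends the class dimension
at the end, giving shape [N, H, W, C] rather than [N, C, H, W]: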

>>> import torch
>>> torch.__version__
'1.9.0'
>>>
>>> t = torch.tensor ([
...     [[[ 0.4008, -0.6662],
...     [-0.4133,  1.3639]],
...
...     [[-0.8354,  0.6317],
...     [ 0.3240, -1.1438]],
...
...     [[-0.3452,  1.2110],
...     [ 0.6575,  0.9924]]]
... ])
>>>
>>> ind = t.argmax (dim = 1)
>>>
>>> ind
tensor([[[0, 2],
         [2, 0]]])
>>>
>>> torch.nn.functional.one_hot (ind).permute (0, 3, 1, 2)
tensor([[[[1, 0],
          [0, 1]],

         [[0, 0],
          [0, 0]],

         [[0, 1],
          [1, 0]]]])
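Two details about one_hot() may matter in practice. Without num_classes, it infers the number of classes as ind.max() + 1, so if the highest channel never wins the argmax, the result silently has too few channels; passing num_classes explicitly avoids that. It also returns an int64 tensor, so a .float() is needed if you want the same dtype as the scatter_ version. A small sketch, using a made-up input in which channel 2 never wins:

```python
import torch
import torch.nn.functional as F

# Made-up [1, 3, 2, 2] input where channel 2 is never the argmax
x = torch.tensor([[[[ 2.0,  3.0],
                    [ 1.0,  0.5]],
                   [[ 0.0,  1.0],
                    [ 2.0,  0.0]],
                   [[-1.0, -1.0],
                    [-1.0, -1.0]]]])
C = x.shape[1]
ind = x.argmax(dim=1)            # tensor([[[0, 0], [1, 0]]])

inferred = F.one_hot(ind)        # num_classes inferred as ind.max() + 1 == 2
print(inferred.shape)            # torch.Size([1, 2, 2, 2]), one channel short

onehot = F.one_hot(ind, num_classes=C).permute(0, 3, 1, 2).float()
print(onehot.shape, onehot.dtype)  # torch.Size([1, 3, 2, 2]) torch.float32
```

With num_classes set and the cast applied, the result matches the scatter_ approach exactly.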

This is arguably a bit more readable, as with “one_hot()” you say
what you mean (at the cost of the permute()).

(Note that, as @ptrblck illustrated, you do not need to use softmax().
Because it doesn’t change the order of the values, leaving it out
doesn’t change the result of argmax().)
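A quick check of that point, using the tensor from the question. Softmax computes exp(x_i) / sum_j exp(x_j) along dim=1; exp is strictly increasing and every entry at a given pixel is divided by the same positive sum, so the ordering across channels (and hence the argmax) is preserved, up to floating-point rounding for nearly equal values:

```python
import torch

x = torch.tensor([[[[ 0.4008, -0.6662],
                    [-0.4133,  1.3639]],
                   [[-0.8354,  0.6317],
                    [ 0.3240, -1.1438]],
                   [[-0.3452,  1.2110],
                    [ 0.6575,  0.9924]]]])

direct = x.argmax(dim=1)                             # argmax on the raw scores
via_softmax = torch.softmax(x, dim=1).argmax(dim=1)  # argmax after softmax

print(torch.equal(direct, via_softmax))  # True
```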

Best.

K. Frank
