How to transform a multi-hot tensor to a one-hot tensor?

For example,

[[1. 0. 1. 0. 0. 0.],
[0. 0. 1. 0. 1. 0.]]

and I want to get

[[[1. 0. 0. 0. 0. 0.],
[0. 0. 1. 0. 0. 0.]],
[[0. 0. 1. 0. 0. 0.],
[0. 0. 0. 0. 1. 0.]]]

What is an elegant way to achieve this operation in PyTorch?

Hi Yu!

Use nonzero() to find locations of the 1s in your multi-hot tensor,
and then (after some fiddling to get the indices in the right form) use
one_hot() to produce your desired one-hot result:

>>> import torch
>>> torch.__version__
'1.10.0'
>>> multihot = torch.tensor ([[1, 0, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0]])
>>> count = multihot[0].sum()   # number of ones in first row
>>> assert (multihot.sum (dim = 1) == count).all()   # verify same number of ones in each row
>>> onehot = torch.nn.functional.one_hot (multihot.nonzero()[:, 1].view (multihot.size (0), count))
>>> onehot
tensor([[[1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0]],

        [[0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1]]])
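To make the "fiddling" step concrete, here is the same pipeline broken into its intermediate values. This is a sketch that assumes, as above, that every row has the same number of ones; the names `positions`, `cols`, and `idx` are just illustrative. Note that without `num_classes`, `one_hot()` infers the number of classes as the largest index plus one, which is why the last dimension comes out as 5 here.

```python
import torch

multihot = torch.tensor([[1, 0, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0]])

# nonzero() returns one (row, column) pair per 1, in row-major order
positions = multihot.nonzero()
print(positions.tolist())   # [[0, 0], [0, 2], [1, 2], [1, 4]]

# keep only the column indices: these are the class labels
cols = positions[:, 1]
print(cols.tolist())        # [0, 2, 2, 4]

# regroup the labels by row (requires the same number of ones per row)
count = int(multihot[0].sum())
idx = cols.view(multihot.size(0), count)
print(idx.tolist())         # [[0, 2], [2, 4]]

# one_hot() appends a class dimension; num_classes is inferred as 4 + 1 = 5
onehot = torch.nn.functional.one_hot(idx)
print(tuple(onehot.shape))  # (2, 2, 5)
```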

Best.

K. Frank

Hi K. Frank!
Thanks for your reply.
But I notice that the shape of multihot is (2, 6), while the shape of onehot is (2, 2, 5).
The final dimensions don't match.
Best,
Yu

Hi Yu!

Yes, you are correct. That was my oversight.

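We can explicitly tell one_hot() the number of classes, rather than
having it "guess" (by default it uses the largest index plus one, which
is 5 here, because no row has a 1 in the last column):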

>>> onehot = torch.nn.functional.one_hot (multihot.nonzero()[:, 1].view (multihot.size (0), count), num_classes = multihot.size (1))
>>> onehot
tensor([[[1, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0]],

        [[0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1, 0]]])

Note the use of num_classes = multihot.size (1).
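If you need this more than once, the steps can be wrapped in a small helper. This is a minimal sketch, not a library function: the name `multihot_to_onehot` is hypothetical, and it assumes a 2-D tensor of 0s and 1s with the same number of ones in every row (it raises an error otherwise, instead of using a bare assert). Because `nonzero()` returns integer indices, a float input like the `1.`/`0.` tensor in the question works too.

```python
import torch
import torch.nn.functional as F

def multihot_to_onehot(multihot: torch.Tensor) -> torch.Tensor:
    """Split each row of a 2-D multi-hot tensor into one one-hot vector per 1.

    Assumes entries are 0/1 and every row contains the same number of ones.
    Returns an int64 tensor of shape (rows, ones_per_row, num_classes).
    """
    rows, num_classes = multihot.shape
    counts = multihot.sum(dim=1)
    count = int(counts[0])
    if not bool((counts == count).all()):
        raise ValueError("every row must contain the same number of ones")
    # column indices of the ones, grouped by row
    idx = multihot.nonzero()[:, 1].view(rows, count)
    # pass num_classes explicitly so the last dim matches multihot.size(1)
    return F.one_hot(idx, num_classes=num_classes)

multihot = torch.tensor([[1, 0, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0]])
onehot = multihot_to_onehot(multihot)
print(tuple(onehot.shape))                        # (2, 2, 6)
# summing over the per-row "ones" dimension recovers the original
print(torch.equal(onehot.sum(dim=1), multihot))   # True
```

The round-trip check at the end is a cheap way to confirm the result: each one-hot slice contributes exactly one 1, so adding them back together rebuilds the multi-hot row.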

Best.

K. Frank