For example, given the multi-hot tensor

```
[[1. 0. 1. 0. 0. 0.],
[0. 0. 1. 0. 1. 0.]]
```

I want to split each row into separate one-hot rows, one per `1`:

```
[[[1. 0. 0. 0. 0. 0.],
[0. 0. 1. 0. 0. 0.]],
[[0. 0. 1. 0. 0. 0.],
[0. 0. 0. 0. 1. 0.]]]
```

What is an elegant way to do this in PyTorch?

Hi Yu!

Use `nonzero()` to find the locations of the `1`s in your multi-hot tensor, and then (after some fiddling to get the indices into the right form) use `one_hot()` to produce your desired one-hot result:

```
>>> import torch
>>> torch.__version__
'1.10.0'
>>> multihot = torch.tensor ([[1, 0, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0]])
>>> count = multihot[0].sum()  # number of ones in first row
>>> assert (multihot.sum (dim = 1) == count).all()  # verify same number of ones in each row
>>> onehot = torch.nn.functional.one_hot (multihot.nonzero()[:, 1].view (multihot.size (0), count))
>>> onehot
tensor([[[1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0]],

        [[0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1]]])
```
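The index "fiddling" in the last step can be unpacked one operation at a time. Below is a sketch using the same `multihot` tensor, printing each intermediate as a plain list; the names `coords`, `cols`, and `idx` are just illustrative, and it assumes (as the `assert` above does) that every row has the same number of `1`s:

```python
import torch
import torch.nn.functional as F

multihot = torch.tensor([[1, 0, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0]])

# nonzero() returns one (row, column) pair per 1, in row-major order
coords = multihot.nonzero()
print(coords.tolist())  # [[0, 0], [0, 2], [1, 2], [1, 4]]

# keep only the column indices, i.e. the positions of the 1s
cols = coords[:, 1]
print(cols.tolist())  # [0, 2, 2, 4]

# regroup the column indices per original row (assumes equal counts per row)
count = int(multihot[0].sum())
idx = cols.view(multihot.size(0), count)
print(idx.tolist())  # [[0, 2], [2, 4]]

# one_hot() expands each index into a one-hot vector along a new last dimension
onehot = F.one_hot(idx)
print(tuple(onehot.shape))  # (2, 2, 5)
```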

Best.

K. Frank

Hi K. Frank!

Thanks for your reply.

But I notice that the shape of `multihot` is `[2, 6]`, while the shape of `onehot` is `[2, 2, 5]`.

The final dimensions are not equal.

Best,

Yu

Hi Yu!

Yes, you are correct. That was my oversight.
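The 5 comes from how `one_hot()` picks its width: when `num_classes` is left at its default of `-1`, it infers the number of classes as one more than the largest index in its input. Here the largest column index is 4 (the last `1` in the second row), so the result has 5 classes. A quick check, using the regrouped index tensor from the example directly (the variable names are illustrative):

```python
import torch
import torch.nn.functional as F

# column indices of the 1s, grouped by row, as produced by nonzero()/view()
idx = torch.tensor([[0, 2], [2, 4]])

inferred = F.one_hot(idx)                 # num_classes inferred as idx.max() + 1
explicit = F.one_hot(idx, num_classes=6)  # width pinned to the original row length

print(int(idx.max()) + 1)     # 5
print(tuple(inferred.shape))  # (2, 2, 5)
print(tuple(explicit.shape))  # (2, 2, 6)
```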

We can explicitly tell `one_hot()` the number of classes, rather than having it “guess”:

```
>>> onehot = torch.nn.functional.one_hot (multihot.nonzero()[:, 1].view (multihot.size (0), count), num_classes = multihot.size (1))
>>> onehot
tensor([[[1, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0]],

        [[0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1, 0]]])
```

Note the use of `num_classes = multihot.size (1)`.
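Packaged as a small helper, the corrected recipe might look like the sketch below. The function name `multihot_to_onehot` is hypothetical, and two behaviours are added assumptions rather than part of the recipe above: it raises an error when rows contain different numbers of ones (otherwise `view()` either fails or, when the total happens to divide evenly, silently regroups indices across rows), and it casts the result back to the input dtype, so a float tensor like the `1.`/`0.` one in the question gives float output:

```python
import torch
import torch.nn.functional as F


def multihot_to_onehot(multihot: torch.Tensor) -> torch.Tensor:
    """Split each row of a 2-D multi-hot tensor into its one-hot components.

    Input shape (batch, num_classes); output shape (batch, count, num_classes),
    where count is the number of nonzero entries per row (must be equal for all rows).
    """
    counts = (multihot != 0).sum(dim=1)
    if not (counts == counts[0]).all():
        raise ValueError("every row must contain the same number of ones")
    count = int(counts[0])

    # column index of each nonzero entry, in row-major order, regrouped per row
    idx = multihot.nonzero()[:, 1].view(multihot.size(0), count)

    # pin num_classes to the row length so the last dimension is not inferred
    onehot = F.one_hot(idx, num_classes=multihot.size(1))
    return onehot.to(multihot.dtype)


multihot = torch.tensor([[1., 0., 1., 0., 0., 0.],
                         [0., 0., 1., 0., 1., 0.]])
result = multihot_to_onehot(multihot)
print(result.shape)  # torch.Size([2, 2, 6])
print(result.dtype)  # torch.float32
```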

Best.

K. Frank