How to reverse a multi-hot encoding

I have my data encoded as multi-hot vectors for a multi-label classification task (4 classes in this example):

multihot_batch = torch.tensor([[0,1,0,1], [0,0,0,1], [0,0,1,1]])

How can I undo this encoding and have each entry be a list of the classes present, like this:

tensor([[1, 3],
        [3],
        [2, 3]])
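Note that a regular torch.Tensor has to be rectangular, so the ragged result above can't be built as one ordinary tensor: torch.tensor([[1, 3], [3], [2, 3]]) raises a ValueError. If a single tensor is really needed, one option is to pad the shorter rows. The following is a minimal vectorized sketch. It assumes -1 (the hypothetical PAD constant) is an acceptable "no class" marker, and it uses a sort-based trick that was not part of the original post.

```python
import torch

multihot_batch = torch.tensor([[0, 1, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1]])
batch_size, num_classes = multihot_batch.shape
PAD = -1  # hypothetical padding value meaning "no class"

active = multihot_batch != 0

# Column (class) index of every entry, one copy per row: [[0, 1, 2, 3], ...]
idx = torch.arange(num_classes).repeat(batch_size, 1)

# Give inactive entries the key num_classes so they sort after every real class.
keyed = idx.masked_fill(~active, num_classes)
sorted_idx, _ = keyed.sort(dim=1)

# Keep only as many columns as the row with the most labels needs,
# then replace the leftover sentinel values with PAD.
max_labels = int(active.sum(dim=1).max())
padded = sorted_idx[:, :max_labels]
padded = padded.masked_fill(padded == num_classes, PAD)
print(padded.tolist())  # [[1, 3], [3, -1], [2, 3]]
```

Rows without any active class come out as all PAD. Downstream code then has to ignore the padding value.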

torch.argmax only returns the index of the first maximum in each row:


torch.argmax(multihot_batch, dim=1, keepdim=True)
tensor([[1],
        [3],
        [2]])

This approach gives them all, but it’s not formatted how I want:

(multihot_batch == torch.tensor(1)).nonzero()
tensor([[0, 1],
        [0, 3],
        [1, 3],
        [2, 2],
        [2, 3]])
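Applying nonzero() to one row at a time, instead of to the whole batch, gives the per-sample grouping directly, as a Python list of 1-D tensors. This is a minimal sketch and a suggested alternative. It loops in Python, so it may be slower than a vectorized approach for very large batches.

```python
import torch

multihot_batch = torch.tensor([[0, 1, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1]])

# On a 1-D row, nonzero(as_tuple=True) returns a 1-tuple whose only element
# holds the positions of the non-zero entries, i.e. the active class indices.
labels = [row.nonzero(as_tuple=True)[0] for row in multihot_batch]
print(labels)  # [tensor([1, 3]), tensor([3]), tensor([2, 3])]
```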

I’d appreciate any suggestions.

Hi, it's not easy, but after an hour of hard thinking I worked it out:

import torch

multihot_batch = torch.tensor([[0,1,0,1], [0,0,0,1], [0,0,1,1]])
# another test input with one extra row:
#multihot_batch = torch.tensor([[0,1,0,1], [0,0,0,1], [0,0,1,1], [0,0,0,1]])
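# (row, class) index pairs of all the ones, in row-major order
vnon = (multihot_batch == torch.tensor(1)).nonzero(as_tuple=False)
v0 = vnon[:,0]  # row (sample) index of each one
v1 = vnon[:,1]  # class index of each one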

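# find where the row index steps up by one; +1 turns the (0-based) position of
# the last entry of a row into the exclusive end of that row's segment
# (note: this assumes every row has at least one 1 and there are at least two rows)
split_ind = ((torch.roll(v0, -1, 0) - v0) == 1).nonzero(as_tuple=False)[:,0] + 1
# the first end index is exactly the size of the first split;
# the other splits (apart from the last one) are the gaps between consecutive end indices
split_size = torch.cat([split_ind[0].view(1),(torch.roll(split_ind, -1, 0) - split_ind)[:-1]])
# the final split takes all remaining elements
final_size = torch.tensor([torch.numel(v1) - torch.sum(split_size)])
split_size = torch.cat([split_size, final_size])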

print(torch.split(v1, split_size.tolist()))
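The split sizes computed with torch.roll above are simply the number of ones in each row, so a shorter variant can take them straight from sum(dim=1). This is a suggested alternative, not the code from the answer. Like the answer, it relies on nonzero() listing the (row, class) pairs in row-major order. Unlike the roll-based version, it also handles rows with no active class and single-row batches, because torch.split accepts section sizes of 0 (yielding an empty tensor for that row).

```python
import torch

multihot_batch = torch.tensor([[0, 1, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1]])
active = multihot_batch != 0

# Class column of the (row, class) pairs, already grouped row by row:
# tensor([1, 3, 3, 2, 3])
classes = active.nonzero(as_tuple=True)[1]

# Each group's length is the number of active classes in that row: [2, 1, 2]
counts = active.sum(dim=1)

# Cut the flat class list into one segment per row.
labels = torch.split(classes, counts.tolist())
print(labels)  # (tensor([1, 3]), tensor([3]), tensor([2, 3]))
```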
