I have my data encoded as multi-hot vectors for a multi-label classification task (4 classes in this example):
multihot_batch = torch.tensor([[0,1,0,1], [0,0,0,1], [0,0,1,1]])
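How can I undo this encoding so that each entry becomes a list of the classes present? Because the rows have different lengths, the result can't be a regular tensor, so a nested list like this is what I'm after:
[[1, 3],
 [3],
 [2, 3]]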
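A regular tensor must be rectangular, so a ragged result like the one above can only be a nested Python list (or a list of tensors). A minimal sketch, assuming a plain Python loop over the rows is fast enough for your batch sizes:

```python
import torch

multihot_batch = torch.tensor([[0, 1, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1]])

# For each 1-D row, nonzero() returns the column indices of the non-zero
# entries as an (n, 1) tensor; flatten() drops the extra dimension and
# tolist() turns the result into a plain Python list of ints.
decoded = [row.nonzero().flatten().tolist() for row in multihot_batch]
print(decoded)  # [[1, 3], [3], [2, 3]]
```

This is easy to read, but it runs one small PyTorch call per row, which may matter for very large batches.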
torch.argmax only returns the index of the first maximum in each row:
torch.argmax(multihot_batch, dim=1, keepdim=True)
tensor([[1],
        [3],
        [2]])
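torch.argmax(..., keepdim=True) yields a rectangular (batch, 1) tensor with exactly one index per row. If a rectangular tensor is what you need downstream, one option is to keep every active class and pad the shorter rows instead. A sketch, assuming -1 is an acceptable padding value (`PAD` is a hypothetical name, not a PyTorch constant):

```python
import torch

multihot_batch = torch.tensor([[0, 1, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1]])
num_classes = multihot_batch.shape[1]
PAD = -1  # hypothetical padding value for unused slots

# Class index for every position, shape (batch, num_classes)
classes = torch.arange(num_classes).expand_as(multihot_batch)
# Inactive positions get the key num_classes, so they sort after every real class
keys = torch.where(multihot_batch.bool(), classes, torch.full_like(classes, num_classes))
sorted_keys, _ = keys.sort(dim=1)
# Keep only as many columns as the row with the most active classes
max_active = int(multihot_batch.sum(dim=1).max())
padded = sorted_keys[:, :max_active]
# Replace the leftover sentinel keys with the padding value
padded = padded.masked_fill(padded == num_classes, PAD)
print(padded)
```

Sorting the keys keeps the active classes in ascending order at the front of each row, so the output matches the ragged lists except for the padding.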
This approach returns all of them, but as (row, class) index pairs rather than one list per row:
(multihot_batch == torch.tensor(1)).nonzero()
tensor([[0, 1],
        [0, 3],
        [1, 3],
        [2, 2],
        [2, 3]])
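The nonzero() output above already contains everything needed: its rows come out in row-major order, so the class column can be cut into one chunk per sample using each row's count of 1s. A sketch of that vectorized regrouping with torch.split:

```python
import torch

multihot_batch = torch.tensor([[0, 1, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1]])

mask = multihot_batch == 1
# (row, class) pairs for every 1, in row-major order
pairs = mask.nonzero()
# Number of active classes in each row, used as the split sizes
counts = mask.sum(dim=1).tolist()
# Split the class column into one tensor per row
per_row = torch.split(pairs[:, 1], counts)
decoded = [t.tolist() for t in per_row]
print(decoded)  # [[1, 3], [3], [2, 3]]
```

`per_row` is a tuple of 1-D tensors; the final list comprehension is only needed if you want plain Python lists.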
I’d appreciate any suggestions.