Does this function already exist?

I have an (N, 3) shaped tensor, and I would like each of its (1, 3) shaped rows to be converted to a (1, 3) shaped row with a 1 where the row's maximum is and a 0 in the other places.

So is there a function that changes the maximum entry of a tensor to a 1 and the rest to 0 along a given axis, not necessarily for an (N, 3) tensor?

For example,

Tensor([[2.3, 3.1, 1.2],
        [5.6, 4.1, 3.1]])

would go to

Tensor([[0, 1, 0],
        [1, 0, 0]])

I’m sure the torch.sort() function on that dimension will do that just fine 🙂

Sorry, I think my wording was bad. I made an edit to hopefully clear things up.

This code snippet should work:

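import torch

x = torch.tensor([[2.3, 3.1, 1.2],
                  [5.6, 4.1, 3.1]])
idx = torch.argmax(x, 1)  # index of the maximum in each row, shape (N,)
# Write a 1 at each row's argmax position; every other entry stays 0.
output = torch.zeros_like(x).scatter_(1, idx[:, None], 1.)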
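
Since the question also asks about an arbitrary axis and shape, the same idea can be wrapped in a small helper. This is only a sketch: the name `one_hot_argmax` and its `dim` parameter are made up for illustration and are not part of PyTorch. It assumes a floating-point input and relies on `keepdim=True` so the index tensor keeps the same number of dimensions as `x`, which is what `scatter_` expects.

```python
import torch


def one_hot_argmax(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: 1 at the argmax along `dim`, 0 elsewhere."""
    # keepdim=True keeps `dim` as a size-1 axis, so idx has as many
    # dimensions as x and can be passed to scatter_ directly.
    idx = x.argmax(dim=dim, keepdim=True)
    # Start from zeros with x's shape and dtype, then write a 1 at each
    # argmax position along `dim`.
    return torch.zeros_like(x).scatter_(dim, idx, 1.0)


x = torch.tensor([[2.3, 3.1, 1.2],
                  [5.6, 4.1, 3.1]])
rows = one_hot_argmax(x, dim=1)  # rows become [0, 1, 0] and [1, 0, 0]
cols = one_hot_argmax(x, dim=0)  # one 1 per column instead of per row
```

Because `argmax` returns a single index per slice, exactly one entry is set to 1 along `dim` even when a slice contains tied maxima.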

Thank you so much! I’ve been stuck on this for hours. Worked like a charm.
