Set Max value to 1, others to 0

Hello guys,

I’m probably just bad at searching, so pardon me if this is a repost.

I have a tensor of shape [batch_size, 4].

And I want the values along the 2nd dimension to be something like [0, 0, 1, 0], where the 1 corresponds to the max value in that row.
Any ideas?

Thanks in advance!

Hello Kevin!

I am assuming that you want your result tensor to have the
same shape as your input tensor, [batch_size, 4], with
the largest value in each row replaced by 1 and the lesser
values replaced by 0.

How about (1.e6 * t).softmax (1)?

(No, no, no … Bad idea! Don’t listen to me!!!)
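For the record, here is a small sketch of one reason the scaled-softmax shortcut is a bad idea: when two entries in a row tie for the maximum, the softmax splits the mass between them instead of producing a single 1, whereas argmax() always picks exactly one position. The tensor values are made up for illustration.

```python
import torch

# Row 0 has two tied maxima; row 1 has a clear maximum.
t = torch.tensor([[1.0, 1.0, 0.0, -1.0],
                  [0.2, 1.5, 0.3, -0.7]])

# The "sharpened softmax" idea: scale by a large factor, then softmax over dim 1.
soft = (1.e6 * t).softmax(1)

# argmax + scatter puts exactly one 1 in each row, even when there is a tie.
hard = torch.zeros_like(t).scatter(1, t.argmax(1, keepdim=True), 1.0)

print(soft)  # row 0 comes out as [0.5, 0.5, 0.0, 0.0], not one-hot
print(hard)
```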

I don’t know of any good, one-step way of doing this. You have
to call argmax() (or do something equivalent) and then “one-hot”
your result. The commonly-suggested way to “one-hot” (that as
far as I know is the best way) is to use scatter().

Try this:

import torch
torch.manual_seed (1414)
t = torch.randn (8, 4)
a = t.argmax (1)
m = torch.zeros (t.shape).scatter (1, a.unsqueeze (1), 1.0)
print ('\n', t, '\n\n', a, '\n\n', m)

Running this, I get

tensor([[ 0.1868,  1.2265,  1.0181,  0.6943],
       [ 0.1075, -0.3540,  0.1766, -0.4940],
       [ 3.0013,  1.8697, -0.0673, -1.3875],
       [-0.0560, -0.6007, -0.0410, -0.6681],
       [-0.2963,  0.1108,  0.2250, -0.5483],
       [-0.1541, -1.1390,  0.4984, -0.4016],
       [-0.3887, -0.1821,  0.4817,  0.4268],
       [ 0.8756,  2.0755,  0.2577,  0.8941]])

tensor([1, 2, 0, 2, 2, 2, 2, 1])

tensor([[0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [1., 0., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.]])

(This is a sufficiently commonly requested manipulation that
I’m surprised that there isn’t a built-in function for it – at
least I haven’t found one.)
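One comparison-based alternative, not suggested in this thread, is to test every element against its row maximum. It is shorter, but it is not a drop-in replacement: when a row has tied maxima it marks all of them, while the argmax/scatter construction above marks exactly one. A sketch with made-up values:

```python
import torch

t = torch.tensor([[0.1868, 1.2265, 1.0181, 0.6943],
                  [3.0, 3.0, -0.1, 0.5]])   # the second row has a tie

# Compare each element with its row maximum (keepdim=True so it broadcasts).
row_max = t.max(dim=1, keepdim=True).values
m_cmp = (t == row_max).float()

# The argmax/scatter construction, for comparison.
m_scatter = torch.zeros_like(t).scatter(1, t.argmax(1, keepdim=True), 1.0)

print(m_cmp)      # the tied row gets two 1s
print(m_scatter)  # every row gets exactly one 1
```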

Happy One-Hotting!

K. Frank

F.one_hot should yield the same result as the scatter operation ;)

import torch
import torch.nn.functional as F
x = torch.tensor([1, 2, 0, 2, 2, 2, 2, 1])
F.one_hot(x, num_classes=4)
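A small sketch checking this against the scatter example above (same seed and shape). One difference worth knowing: F.one_hot returns an int64 tensor, so it needs a cast to match the float output of scatter().

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1414)
t = torch.randn(8, 4)
a = t.argmax(1)

# scatter version: a float tensor with the same dtype as t
m_scatter = torch.zeros_like(t).scatter(1, a.unsqueeze(1), 1.0)

# one_hot version: returns int64, so convert it to t's dtype
m_onehot = F.one_hot(a, num_classes=t.shape[1]).to(t.dtype)

print(F.one_hot(a, num_classes=4).dtype)  # torch.int64
print(torch.equal(m_scatter, m_onehot))   # True
```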

Hi Peter!

Thanks @ptrblck. one_hot() is just the ticket to replace the
scatter() / unsqueeze() nonsense.

Best.

K. Frank


Oops! Spoke too soon.

This github issue:

https://github.com/pytorch/pytorch/issues/15060

suggests that one_hot is there, but is recent, as does this entry
in the documentation:

https://pytorch.org/docs/master/nn.html#one-hot

However, this analogous url for the “stable” documentation
suggests that one_hot hasn’t yet made it to prime time:

https://pytorch.org/docs/stable/nn.html#one_hot

In any event, my 1.0.1 version of pytorch doesn’t have it yet:

>>> import torch
>>> torch.version.__version__
'1.0.1'
>>> torch.nn.functional.one_hot (torch.tensor ([1, 2, 3, 2, 1]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: module 'torch.nn.functional' has no attribute 'one_hot'

(I don’t understand github or the documentation well enough to
be able to figure out in which version one_hot first appeared.)
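If the same code has to run on a release like 1.0.1 as well as on newer ones, one option is to check for one_hot at runtime and fall back to scatter(). The helper name to_one_hot is hypothetical, and this sketch only handles 1-D index tensors.

```python
import torch
import torch.nn.functional as F

def to_one_hot(indices, num_classes):
    """Hypothetical helper: one-hot encode a 1-D LongTensor of class indices.

    Uses F.one_hot when the installed PyTorch provides it, and falls back
    to zeros().scatter() on older releases that lack it. Returns float32.
    """
    if hasattr(F, 'one_hot'):
        return F.one_hot(indices, num_classes=num_classes).float()
    out = torch.zeros(indices.shape[0], num_classes)
    return out.scatter(1, indices.unsqueeze(1), 1.0)

print(to_one_hot(torch.tensor([1, 2, 3, 2, 1]), 4))
```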

Best.

K. Frank

Thank you, this is exactly what I was looking for!

I have a follow up question though:
When I use this approach in my custom loss, I get the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I guess autograd has trouble with the one_hot function.
Any ideas on how to solve this?

Thanks in advance!
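A likely explanation, with a sketch: argmax() returns integer indices, and autograd cannot differentiate through them, so the one-hot tensor built from them has no grad_fn. If the loss depends on the model output only through that tensor, backward() raises exactly this error (and even in principle, a step function has zero gradient almost everywhere). One common workaround, which is not from this thread and is only an assumption about what your loss needs, is a straight-through estimator: use the hard one-hot values in the forward pass but the softmax gradient in the backward pass. The target and loss below are hypothetical placeholders.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 4, requires_grad=True)

# argmax returns integer indices, so the one-hot tensor has no grad_fn.
hard = F.one_hot(logits.argmax(1), num_classes=4).float()
print(hard.requires_grad, hard.grad_fn)  # False None

# Straight-through sketch (assumption, not from this thread):
# forward value equals hard (up to float rounding); gradient flows through soft.
soft = logits.softmax(1)
st = hard - soft.detach() + soft

# Hypothetical target and loss, only to show that backward() now runs.
target = torch.zeros(8, 4)
target[:, 0] = 1.0
loss = ((st - target) ** 2).mean()
loss.backward()
print(logits.grad.shape)  # torch.Size([8, 4])
```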

Hi,
I’m not sure whether the F.one_hot answer extends to the n-dimensional case, and I’m also not sure how to do it with reshape operations. @ptrblck

Assume I have a tensor where the first two dims are batch and channel, and the last three correspond to xyz space:

A = torch.randn(b,c,32,32,32)

What I would like to do is binarize along the x dimension (dim=2) for every batch and channel, i.e., for every yz location I want to set the maximum value along the x-axis to 1 and the rest to 0. Is there a way of doing this?

Thank you!

If I understand this use case correctly, this should work:

import torch
x = torch.randn(2, 3, 4, 4, 4)
idx = torch.argmax(x, dim=2, keepdim=True)  # keep dim 2 (size 1) so scatter_ can use it
ret = torch.zeros_like(x).scatter_(2, idx, 1.)
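For the F.one_hot part of the question, here is a sketch, assuming PyTorch 1.7 or later (where movedim() exists): F.one_hot always appends the class dimension last, so after one-hot encoding the argmax along dim 2 you can move that last dimension back to position 2. Distinct sizes are used so the shapes are easy to follow, and the result is checked against the scatter_ version.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5, 6, 7)

# scatter_ version, as in the answer above
idx = torch.argmax(x, dim=2, keepdim=True)
ret = torch.zeros_like(x).scatter_(2, idx, 1.)

# one_hot version: argmax over dim 2 gives shape (2, 3, 6, 7);
# one_hot appends the class dim last, giving (2, 3, 6, 7, 5).
oh = F.one_hot(x.argmax(dim=2), num_classes=x.shape[2])
# Move the class dim back to position 2 -> (2, 3, 5, 6, 7), and match x's dtype.
oh = oh.movedim(-1, 2).to(x.dtype)

print(torch.equal(ret, oh))  # True
```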