Convert torch tensor to boolean according to max values

Hi there,

I have a very newbie question …

I have a torch tensor of float values, such as:

torch.tensor([[1., 2., 3.], [1., 3., 2.]])

From it, I want to create a mask of 0s and 1s, where a 1 marks the position of the maximum value in each row:

torch.tensor([[0,0,1],[0,1,0]])

I could use the indices returned by torch.max(dim=-1,…), iterate over them, and write a 1 at each position, but I'd like to know whether there is a faster way to create such masks.
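For comparison, the loop-based approach described above might look roughly like this. It is a minimal sketch: the names `idx` and `mask` are illustrative, and the Python loop is exactly the per-row work a vectorized solution avoids.

```python
import torch

x = torch.tensor([[1., 2., 3.], [1., 3., 2.]])

# torch.max with a dim argument returns (values, indices)
_, idx = torch.max(x, dim=-1)

# Slow path: fill the mask one row at a time in a Python loop
mask = torch.zeros_like(x)
for row, col in enumerate(idx.tolist()):
    mask[row, col] = 1.

print(mask)
```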

Thank you!

Using the indices from torch.max sounds like a good idea, but you don't have to iterate over them; you can use them directly as an index:

import torch

x = torch.tensor([[1., 2., 3.], [1., 3., 2.]])
y = torch.zeros_like(x)
# pair each row index with that row's argmax column and set those entries to 1
y[torch.arange(y.size(0)), x.argmax(dim=1)] = 1.
print(y)
> tensor([[0., 0., 1.],
          [0., 1., 0.]])
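A few other vectorized ways to build the same mask are sketched below; these are alternatives, not part of the answer above. `one_hot` returns an int64 tensor, `scatter_` keeps the dtype of `x`, and comparing against the row maximum gives a true boolean mask. Note that the comparison marks every position that ties for the maximum, while `argmax` picks a single index per row.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1., 2., 3.], [1., 3., 2.]])
idx = x.argmax(dim=1)  # column index of each row's maximum

# Option 1: one_hot expands each index into a 0/1 row (int64 result)
mask_one_hot = F.one_hot(idx, num_classes=x.size(1))

# Option 2: scatter a 1 into a zero tensor along dim 1 (keeps x's dtype)
mask_scatter = torch.zeros_like(x).scatter_(1, idx.unsqueeze(1), 1.)

# Option 3: boolean mask by comparing each row against its max;
# unlike argmax, this marks every position that ties for the maximum
mask_bool = x == x.max(dim=1, keepdim=True).values

print(mask_one_hot)
print(mask_scatter)
print(mask_bool)
```

If you need a boolean tensor from options 1 or 2, calling `.bool()` on the result converts it.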