Convert torch tensor to boolean according to max values

Using the indices returned by torch.max is a good idea, but you don't need to iterate over them; you can use them directly for indexing:

import torch

x = torch.tensor([[1., 2., 3.], [1., 3., 2.]])
y = torch.zeros_like(x)
# For each row, set the column holding that row's maximum to 1
y[torch.arange(y.size(0)), x.argmax(dim=1)] = 1.
print(y)
> tensor([[0., 0., 1.],
          [0., 1., 0.]])
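Note that y above has the same float dtype as x. If you need an actual boolean tensor, as the title asks, one possible sketch is to build the mask with torch.nn.functional.one_hot and cast it with .bool(). The names mask and mask_ties are just illustrative, and this assumes a 2D input with the maximum taken along dim=1:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1., 2., 3.], [1., 3., 2.]])

# argmax returns int64 indices, which is what one_hot expects
mask = F.one_hot(x.argmax(dim=1), num_classes=x.size(1)).bool()
print(mask)

# Alternative: compare each row to its own maximum.
# Unlike argmax, this marks every tied maximum in a row as True.
mask_ties = x == x.max(dim=1, keepdim=True).values
print(mask_ties)
```

For this x both masks are identical, since each row has a single maximum. They differ only when a row contains ties: the argmax version marks a single position per row, while the comparison marks all of them.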