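Using the indices returned by `torch.max` sounds like a good idea, but you wouldn’t have to iterate over them; you can use them directly as an index. (`x.argmax(dim=1)` gives the same indices as the second output of `torch.max(x, dim=1)`.)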
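```python
import torch

x = torch.tensor([[1., 2., 3.], [1., 3., 2.]])
y = torch.zeros_like(x)
# Pair each row index with that row's argmax column and set those entries to 1.
y[torch.arange(y.size(0)), x.argmax(dim=1)] = 1.
print(y)
# tensor([[0., 0., 1.],
#         [0., 1., 0.]])
```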
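For comparison, here is a small sketch that takes the indices straight from `torch.max` (its second return value) and builds the same one-hot mask in three ways. The `scatter_` and `F.one_hot` variants are alternatives I'm adding for illustration, not part of the snippet above; the variable names are arbitrary.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[1., 2., 3.], [1., 3., 2.]])

# torch.max with dim returns (values, indices); only the indices are needed
_, idx = torch.max(x, dim=1)  # idx == tensor([2, 1])

# 1) Advanced indexing: row i, column idx[i]
y_index = torch.zeros_like(x)
y_index[torch.arange(x.size(0)), idx] = 1.

# 2) scatter_: write 1. along dim 1 at idx (index must have shape [N, 1])
y_scatter = torch.zeros_like(x).scatter_(1, idx.unsqueeze(1), 1.)

# 3) one_hot returns an int64 tensor, so cast it back to x's dtype
y_onehot = F.one_hot(idx, num_classes=x.size(1)).to(x.dtype)

print(torch.equal(y_index, y_scatter) and torch.equal(y_index, y_onehot))  # True
```

All three produce the same tensor here; the indexing version avoids any Python loop, which is the point of the answer, while `scatter_` and `one_hot` can be handy when the index tensor is already shaped for them.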