`torch.topk` returns the top-k values and their indices along one dimension of a tensor (the last dimension by default). For a 2D tensor, how do I set the entries at those indices to 1 and everything else to 0?

For example, suppose I have the 2D tensor `a = tensor([[1, 2, 3], [6, 5, 4]])`. If I call `value, index = torch.topk(a, 2, largest=True)`, then `index = tensor([[2, 1], [0, 1]])`. I want the output to be `b = tensor([[0, 1, 1], [1, 1, 0]])`, computed without a `for` loop.
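A minimal reproduction of the setup, assuming only that PyTorch is installed. It confirms the indices `topk` returns for this input (with `sorted=True`, the default, the largest value comes first in each row):

```python
import torch

a = torch.tensor([[1, 2, 3], [6, 5, 4]])
# topk works along the last dimension by default; largest=True is also the default
value, index = torch.topk(a, 2, largest=True)
print(value.tolist())  # [[3, 2], [6, 5]]
print(index.tolist())  # [[2, 1], [0, 1]]
```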

Advanced indexing with a row-index vector broadcast against the transposed `topk` indices does this without a loop:

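```python
import torch

a = torch.tensor([[1, 2, 3], [6, 5, 4]])
value, index = torch.topk(a, 2, largest=True)

b = torch.zeros_like(a)
# torch.arange(b.size(0)) has shape (N,) and index.t() has shape (k, N); they broadcast
# so that position [j, i] pairs row i with column index[i, j]
b[torch.arange(b.size(0)), index.t()] = 1
print(b)
# tensor([[0, 1, 1],
#         [1, 1, 0]])
```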
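An alternative sketch uses `Tensor.scatter_`, which writes a scalar into the positions given by an index tensor along a chosen dimension. This is a different technique from the indexing answer above, shown here for comparison; it avoids the transpose because `index` already has one row per row of `a`:

```python
import torch

a = torch.tensor([[1, 2, 3], [6, 5, 4]])
_, index = torch.topk(a, 2, dim=1)

# scatter_ along dim=1 sets b[i][index[i][j]] = 1 for every (i, j)
b = torch.zeros_like(a).scatter_(1, index, 1)
print(b.tolist())  # [[0, 1, 1], [1, 1, 0]]
```

Because `b` is created with `zeros_like(a)`, it keeps the dtype of `a` (here `int64`); call `.float()` or `.bool()` on it if a different mask type is needed.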