Change values of top K in every row of tensor

I have a tensor of size N*K containing float numbers. I want to set the top K values in every row to 1 and all other values to 0. What is the best way to do this in PyTorch?
Thanks.

I assume the second K (the number of top values to keep) is smaller than the first one (the number of columns).
This code should work:

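import torch

x = torch.randint(0, 10, (10, 5))
k = 3
# kidx holds the column indices of the k largest values in each row
kvals, kidx = x.topk(k=k, dim=1)
x.zero_()
# Row indices of shape (N, 1) broadcast against kidx of shape (N, k)
x[torch.arange(x.size(0))[:, None], kidx] = 1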
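
If you would rather keep x unchanged, the same 0/1 result can probably be built in a separate tensor with scatter_. This is only a minimal sketch, assuming a 2D float input as in the question; the helper name topk_mask is made up for illustration and is not part of PyTorch.

```python
import torch

def topk_mask(x: torch.Tensor, k: int) -> torch.Tensor:
    # Hypothetical helper: returns a 0/1 tensor marking each row's k largest entries.
    _, kidx = x.topk(k=k, dim=1)
    mask = torch.zeros_like(x)
    # Write 1 at the top-k column indices of every row; x itself is left untouched.
    mask.scatter_(1, kidx, 1.0)
    return mask

x = torch.tensor([[0.1, 0.9, 0.5, 0.3],
                  [0.8, 0.2, 0.7, 0.4]])
mask = topk_mask(x, k=2)
print(mask)
```

Note that when a row contains ties, topk still returns exactly k indices, so which of the tied entries gets marked is up to topk.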

You’re amazing! Thanks.
