I have a tensor of size N*K with float values. I want to set the top K values in every row to 1 and all other values to 0. What is the best way to do this in PyTorch?
Thanks.
I assume the second K (the number of values to keep per row) is smaller than the first one (the row length).
This code should work:
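```python
import torch

x = torch.randint(0, 10, (10, 5))  # example input; works the same for a float tensor
k = 3
# Values and column indices of the k largest entries in each row
kvals, kidx = x.topk(k=k, dim=1)
x.zero_()  # set every entry to 0 in place
# Row indices of shape (N, 1) broadcast against kidx of shape (N, k)
x[torch.arange(x.size(0))[:, None], kidx] = 1
```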
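If you'd rather leave `x` untouched and get a separate 0/1 mask, a variant using `scatter_` should also work. This is a sketch, and `topk_mask` is just a hypothetical helper name. Each row gets exactly k ones. If several entries tie, I wouldn't rely on which of the tied entries `topk` picks.

```python
import torch

def topk_mask(x: torch.Tensor, k: int) -> torch.Tensor:
    # Hypothetical helper: column indices of the k largest entries in each row
    _, idx = x.topk(k, dim=1)
    # All-zero tensor with x's shape/dtype, then write 1.0 at those indices along dim 1
    return torch.zeros_like(x).scatter_(1, idx, 1.0)

x = torch.tensor([[0.1, 0.9, 0.5, 0.3],
                  [2.0, -1.0, 0.0, 3.0]])
mask = topk_mask(x, 2)
print(mask)
```

Unlike the in-place version, this allocates a new tensor, so the original float values in `x` are still available afterwards.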
You’re amazing! Thanks.