How to select top-k values with different k per row

Hi, I am looking for a way to select the top-k values in each row of an N x d input tensor A, given an N x 1 vector K whose i-th element is the value of k for the i-th row of A. The output can be a boolean mask with the same shape as A.

If k is constant, we can simply use torch.topk, but is there an efficient way to do this when K is a vector?
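One way to keep using torch.topk is to call it once with the largest k in the batch and then discard, per row, the positions beyond K[i]. Below is a minimal sketch, assuming K is an int64 tensor of shape (N, 1) on the same device as A with 1 <= K[i] <= d; the function name `topk_mask_via_topk` is hypothetical. Exactly K[i] entries are selected in every row, but which of several tied entries gets picked is left to torch.topk and is not guaranteed.

```python
import torch

def topk_mask_via_topk(A: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Boolean mask that selects the K[i] largest entries of row i of A.

    A: (N, d) tensor. K: (N, 1) int64 tensor with 1 <= K[i] <= d.
    """
    k_max = int(K.max())
    # One batched topk call with the largest k any row needs; idx holds column
    # indices sorted from largest to smallest value.
    _, idx = A.topk(k_max, dim=1)                                   # (N, k_max)
    # Position j of row i is kept only if j < K[i] (broadcast over rows).
    keep = torch.arange(k_max, device=A.device).unsqueeze(0) < K    # (N, k_max)
    # Write the keep flags back to the original column positions.
    mask = torch.zeros_like(A, dtype=torch.bool)
    mask.scatter_(1, idx, keep)
    return mask
```

The cost is a single topk with k = max(K), so this is most attractive when the values in K are small compared with d.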

I tried finding the k-th largest value in each row and using it as a threshold, but this does not work when a row contains duplicate values:

```python
# A: (N, d) float tensor; K: (N, 1) int64 tensor with 1 <= K[i] <= d
kth_max = A.sort(dim=1, descending=True).values.gather(1, K - 1)
mask = A >= kth_max  # This is wrong: ties can select more than K[i] values in a row
```
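The threshold idea can be made tie-aware without noise: keep every entry strictly greater than the k-th value, then fill the remaining slots with entries equal to it, taking them left to right. For example, with row [4, 2, 2, 1] and k = 2, the plain `>=` threshold keeps 4, 2, 2 (three entries), while the version below keeps 4 and the first 2. This is a sketch under the same assumptions about the shapes and dtype of A and K; `topk_mask_threshold` is a hypothetical name.

```python
import torch

def topk_mask_threshold(A: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Tie-aware k-th-value threshold. A: (N, d); K: (N, 1) int64, 1 <= K[i] <= d."""
    kth_max = A.sort(dim=1, descending=True).values.gather(1, K - 1)  # (N, 1)
    above = A > kth_max        # strictly larger entries are always selected
    tied = A == kth_max        # candidates for the slots that are left
    # Slots left for tied entries in each row (at least 1, since the k-th value itself ties).
    slots = K - above.sum(dim=1, keepdim=True)                        # (N, 1)
    # 1-based running count of tied entries from left to right; keep the first `slots`.
    tie_rank = tied.long().cumsum(dim=1)
    return above | (tied & (tie_rank <= slots))

A = torch.tensor([[4., 2., 2., 1.],
                  [9., 7., 3., 5.]])
K = torch.tensor([[2], [3]])
print(topk_mask_threshold(A, K).sum(dim=1))  # tensor([2, 3])
```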

Any help is really appreciated.

Edit: I worked around the duplicate values by adding a small random noise to the elements of A, but is it possible to solve this without adding noise?
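Another noise-free option is to compute each element's rank within its row and compare the rank against K. A stable sort (the `stable=True` argument of torch.sort) keeps tied entries in their original left-to-right order, so every entry gets a distinct rank and `rank < K` selects exactly K[i] entries per row. A minimal sketch, assuming the same shapes and dtypes as above; `topk_mask_by_rank` is a hypothetical name.

```python
import torch

def topk_mask_by_rank(A: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """A: (N, d); K: (N, 1) int64 with 0 <= K[i] <= d."""
    N, d = A.shape
    # Column indices ordered from largest to smallest value; stable=True keeps
    # tied entries in their original column order.
    order = torch.sort(A, dim=1, descending=True, stable=True).indices  # (N, d)
    # Invert the permutation: rank[i, order[i, j]] = j, i.e. the position of
    # each element of row i in that row's sorted order.
    positions = torch.arange(d, device=A.device).expand(N, d)
    rank = torch.empty_like(order).scatter_(1, order, positions)
    # Distinct ranks 0..d-1 per row, so exactly K[i] entries satisfy rank < K[i].
    return rank < K
```

This sorts the full row, so it costs O(d log d) per row regardless of K, in exchange for a deterministic tie-break (leftmost tied entries win).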
