Hi, I am looking for a way to select the top-k values in each row of an N x d tensor A, given an N x 1 vector K whose i-th element gives the value of k for the i-th row of A. The output could be a mask tensor with the same shape as A.
If k is constant, we can simply use torch.topk, but I wonder if there is any efficient way to do that with K as a vector?
I tried to find the k-th maximum of each row and use it as a threshold, but this does not work because a row can contain duplicate values:
kth_max = A.sort(dim=1, descending=True).values.gather(1, K - 1)
mask = A >= kth_max # This is wrong: could select more than k values per row
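To make the failure concrete, here is a small reproduction. The toy values of A and K are made up for illustration, and I am assuming K is an int64 tensor of shape N x 1 with 1 <= K[i] <= d:

```python
import torch

# Toy input (hypothetical values): row 0 has a tie at its maximum.
A = torch.tensor([[3., 3., 1.],
                  [2., 5., 4.]])
K = torch.tensor([[1], [2]])  # k per row, shape N x 1, dtype int64

# Threshold approach from above: the k-th largest value of each row.
kth_max = A.sort(dim=1, descending=True).values.gather(1, K - 1)
mask = A >= kth_max

# Number of selected elements per row; row 0 gets 2 although K[0] is 1.
selected_per_row = mask.sum(dim=1)
print(selected_per_row)
```

Both copies of 3 in row 0 pass the threshold, so that row ends up with two selected values instead of one.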
Any help is really appreciated.
Edit: I overcame the duplicate values by adding small random noise to the elements of A, but I was wondering if it would be possible to solve this without adding noise?
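One noise-free direction I can think of is to select by rank instead of by value: sort the column indices of each row, keep the first K[i] sorted positions, and scatter that decision back to the original columns. This is only a rough sketch, not a tuned solution. The helper name topk_mask is hypothetical, and it assumes K is an int64 tensor of shape N x 1 on the same device as A, with 0 <= K[i] <= d:

```python
import torch

def topk_mask(A, K):
    # Hypothetical helper. Assumes A is N x d and K is an int64 tensor
    # of shape N x 1 on the same device, with 0 <= K[i] <= d.
    # Column indices that sort each row in descending order.
    order = A.argsort(dim=1, descending=True)
    # Sorted position j of row i is kept when j < K[i] (broadcasts to N x d).
    positions = torch.arange(A.size(1), device=A.device).unsqueeze(0)
    keep_sorted = positions < K
    # Write each decision back to the column the sorted element came from.
    mask = torch.zeros_like(A, dtype=torch.bool)
    mask.scatter_(1, order, keep_sorted)
    return mask

A = torch.tensor([[3., 3., 1.],
                  [2., 5., 4.]])
K = torch.tensor([[1], [2]])
mask = topk_mask(A, K)
print(mask.sum(dim=1))  # exactly K[i] entries are selected in each row
```

Because every element gets a distinct sorted position, each row has exactly K[i] entries set even when values repeat. Which of several tied elements is chosen depends on how the sort orders them.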