How to efficiently get the K-th largest values in parallel?

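I have a tensor data with shape [b, c] and a tensor K with shape [b]. For each row i, I need the K[i]-th value of data[i] (the loop below uses kthvalue, which returns the K[i]-th smallest value), computed for all rows in parallel. How can I do this efficiently?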

The following code shows what I need:

import torch

b, c = 10, 6
data = torch.randn(b, c)
K = torch.randint(low=1, high=c, size=(b,))
result = []
for i in range(b):
    result.append(data[i].kthvalue(int(K[i])).values)
result = torch.tensor(result)

print(data)
print(K)
print(result)

Since both topk and kthvalue accept only a single integer k shared by all rows, I do not know how to write an efficient parallel version of the loop above.
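One possible way around the single-k restriction, sketched here under the assumption that the largest requested rank K.max() is small compared with c: call topk once with k = K.max() and largest=False, which returns the K.max() smallest entries of every row in ascending order, and then pick each row's own rank with gather. The helper name kth_smallest_topk is hypothetical, not part of PyTorch.

```python
import torch


def kth_smallest_topk(data: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: data[i].kthvalue(K[i]).values for every row i, without a Python loop.

    Assumes data has shape [b, c] and K holds 1-based ranks in [1, c] with shape [b].
    """
    k_max = int(K.max())  # one shared k for topk; on CUDA this forces a device-to-host sync
    # The k_max smallest values of each row, sorted ascending: shape [b, k_max].
    smallest = torch.topk(data, k_max, dim=1, largest=False, sorted=True).values
    # kthvalue ranks are 1-based while gather indices are 0-based, hence K - 1.
    idx = (K - 1).unsqueeze(1).to(torch.long)  # shape [b, 1]
    return smallest.gather(1, idx).squeeze(1)  # shape [b]


b, c = 10, 6
data = torch.randn(b, c)
K = torch.randint(low=1, high=c + 1, size=(b,))  # ranks 1..c inclusive
loop = torch.stack([data[i].kthvalue(int(K[i])).values for i in range(b)])
fast = kth_smallest_topk(data, K)
print(torch.equal(loop, fast))
```

When K.max() is close to c, topk does roughly the work of a full sort, so sorting every row outright is the simpler choice.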


P.S.

I am running this computation in the forward method of my neural network, so an efficient implementation really matters to me.

For now I am using the naive loop above, which may become slow when the batch size b is large, since it calls kthvalue once per row.

I think gather can achieve this: sort each row, then gather column K - 1 of each row (note the off-by-one change, since kthvalue's k is 1-based while gather indices are 0-based):

import torch
import time

b, c = 10, 6
data = torch.randn(b, c)
K = torch.randint(low=1, high=c, size=(b,))
result = []

t1 = time.time()
for i in range(b):
    result.append(data[i].kthvalue(int(K[i])).values)
result = torch.tensor(result)
t2 = time.time()

result2 = torch.gather(torch.sort(data)[0], 1, (K-1).unsqueeze(1)).squeeze(1)  # squeeze(1) keeps a 1-D result even when b == 1
t3 = time.time()

print(data)
print(K)
print(result)
print(result2)
print(t2 - t1, t3 - t2)
tensor([[-0.1341,  0.0766, -1.1398, -0.9177,  1.4444,  0.2596],
        [ 1.5937,  0.2464, -0.5563, -0.5010, -0.5085,  0.0087],
        [-1.2250, -0.9519, -0.1867,  0.1060, -1.1537, -0.5210],
        [-0.0450, -1.7052, -0.4110,  0.2150, -0.1462,  1.1294],
        [-1.4275, -0.0268, -1.6176,  0.4198,  0.4801,  1.3176],
        [ 1.0741, -0.6166,  0.3744, -1.0931, -0.1258, -0.0549],
        [-1.3071, -0.7423,  0.3884,  0.6894, -1.7652,  0.9175],
        [ 0.5899,  1.5174,  1.0743,  0.5756,  0.4595, -1.4915],
        [ 1.7524, -0.4349, -0.7347,  0.2961, -0.1927, -0.7609],
        [-1.1059, -0.0274, -1.3685, -0.3434, -0.8812,  0.3865]])
tensor([1, 5, 3, 4, 5, 4, 5, 5, 1, 1])
tensor([-1.1398,  0.2464, -0.9519, -0.0450,  0.4801, -0.0549,  0.6894,  1.0743,
        -0.7609, -1.3685])
tensor([-1.1398,  0.2464, -0.9519, -0.0450,  0.4801, -0.0549,  0.6894,  1.0743,
        -0.7609, -1.3685])
0.00020599365234375 0.00012087821960449219
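The same sort-and-gather idea packaged as a reusable function, offered as a sketch rather than a definitive implementation: torch.sort along dim=1 orders each row, and gather picks column K[i] - 1 of row i. The function name kth_value_per_row and its largest flag are hypothetical (they belong to this helper, not to any PyTorch API); largest=True sorts in descending order and therefore returns the K-th largest value the title mentions, while the default reproduces kthvalue, i.e. the K-th smallest.

```python
import torch


def kth_value_per_row(data: torch.Tensor, K: torch.Tensor, largest: bool = False) -> torch.Tensor:
    """Hypothetical helper: for each row i of data ([b, c]), return its K[i]-th smallest
    value, or its K[i]-th largest if largest=True. K holds 1-based ranks in [1, c]."""
    sorted_vals = torch.sort(data, dim=1, descending=largest).values
    idx = (K - 1).long().unsqueeze(1)  # 1-based ranks -> 0-based column indices, shape [b, 1]
    return sorted_vals.gather(1, idx).squeeze(1)  # squeeze only dim 1 so b == 1 stays 1-D


data = torch.tensor([[0.5, -1.0, 2.0],
                     [3.0, 1.0, -2.0]])
K = torch.tensor([1, 2])
print(kth_value_per_row(data, K))                # K-th smallest of each row
print(kth_value_per_row(data, K, largest=True))  # K-th largest of each row
```

Both torch.sort (its values output) and gather are differentiable with respect to data, so a helper like this can sit inside a forward method. Also note that the timings printed above come from time.time() on a small CPU tensor; for CUDA tensors, kernels run asynchronously, so torch.cuda.synchronize() before reading the clock (or torch.utils.benchmark) gives more trustworthy numbers.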