# How to efficiently get the K-th largest values in parallel?

I have a tensor `data` of shape `[b, c]` and a tensor `K` of shape `[b]`. How can I efficiently get, for every row `i`, the value at rank `K[i]` in parallel? The following code shows what I need (note that `kthvalue` returns the k-th smallest element of each row):

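```python
import torch

b, c = 10, 6
data = torch.randn(b, c)
# One 1-based rank per row; `high` is exclusive, so ranks fall in [1, c - 1]
K = torch.randint(low=1, high=c, size=(b,))

# Naive version: one kthvalue call per row
result = []
for i in range(b):
    result.append(data[i].kthvalue(int(K[i])).values)
result = torch.stack(result)  # shape [b]

print(data)
print(K)
print(result)
```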

Both `topk` and `kthvalue` accept only a single integer `k` as the ranking parameter, so I do not know how to vectorize the loop above when each row needs a different `k`.
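To make the constraint concrete, here is a minimal sketch of what the two calls return for a single integer `k`. `kthvalue` returns the k-th smallest element along a dimension, while `topk` returns the `k` largest by default. The tensor `x` below is made up for illustration.

```python
import torch

# Hypothetical 2 x 4 example tensor, chosen so the ranks are easy to read off.
x = torch.tensor([[3.0, 1.0, 4.0, 2.0],
                  [10.0, 40.0, 20.0, 30.0]])

# kthvalue: one integer k for the whole batch; k-th SMALLEST value per row.
second_smallest = x.kthvalue(2, dim=1).values

# topk: one integer k for the whole batch; the k LARGEST values per row,
# returned in descending order by default.
top2 = x.topk(2, dim=1).values

# For a single shared k, the k-th largest is the last column of topk(k).
second_largest = top2[:, -1]

print(second_smallest)  # tensor([ 2., 20.])
print(second_largest)   # tensor([ 3., 30.])
```

Neither call takes a tensor of per-row ranks, which is why a different mechanism is needed when `k` varies across the batch.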

P.S.

I am using the above process in the `forward` method of my neural network, so an efficient implementation matters to me.

For now I am using the naive loop above, which may become slow when the batch size `b` is large.

I think some version of `gather` can achieve this: sort each row, then pick the entry at index `K - 1` (note the shift from 1-based ranks to 0-based indices):

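```python
import time

import torch

b, c = 10, 6
data = torch.randn(b, c)
# One 1-based rank per row; `high` is exclusive, so ranks fall in [1, c - 1]
K = torch.randint(low=1, high=c, size=(b,))

# Baseline: one kthvalue call per row (k-th smallest element of each row)
result = []
t1 = time.time()
for i in range(b):
    result.append(data[i].kthvalue(int(K[i])).values)
result = torch.stack(result)
t2 = time.time()

# Vectorized: sort each row in ascending order, then gather column K - 1.
# torch.sort returns (values, indices); only the values are needed here.
sorted_data = torch.sort(data, dim=1).values
result2 = torch.gather(sorted_data, 1, (K - 1).unsqueeze(1)).squeeze(1)
t3 = time.time()

print(data)
print(K)
print(result)
print(result2)
print(t2 - t1, t3 - t2)
```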
```
tensor([[-0.1341,  0.0766, -1.1398, -0.9177,  1.4444,  0.2596],
        [ 1.5937,  0.2464, -0.5563, -0.5010, -0.5085,  0.0087],
        [-1.2250, -0.9519, -0.1867,  0.1060, -1.1537, -0.5210],
        [-0.0450, -1.7052, -0.4110,  0.2150, -0.1462,  1.1294],
        [-1.4275, -0.0268, -1.6176,  0.4198,  0.4801,  1.3176],
        [ 1.0741, -0.6166,  0.3744, -1.0931, -0.1258, -0.0549],
        [-1.3071, -0.7423,  0.3884,  0.6894, -1.7652,  0.9175],
        [ 0.5899,  1.5174,  1.0743,  0.5756,  0.4595, -1.4915],
        [ 1.7524, -0.4349, -0.7347,  0.2961, -0.1927, -0.7609],
        [-1.1059, -0.0274, -1.3685, -0.3434, -0.8812,  0.3865]])
tensor([1, 5, 3, 4, 5, 4, 5, 5, 1, 1])
tensor([-1.1398,  0.2464, -0.9519, -0.0450,  0.4801, -0.0549,  0.6894,  1.0743,
        -0.7609, -1.3685])
tensor([-1.1398,  0.2464, -0.9519, -0.0450,  0.4801, -0.0549,  0.6894,  1.0743,
        -0.7609, -1.3685])
0.00020599365234375 0.00012087821960449219
```
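The sort-then-`gather` idea can be wrapped into small helpers and checked against the loop. This is a sketch under stated assumptions rather than a tuned implementation: the names `kth_smallest_per_row` and `kth_largest_per_row` are made up here, `K` is assumed to hold 1-based ranks in `[1, c]`, and the gradient check assumes no ties within a row. The descending variant is included because the question title asks for the K-th largest, whereas `kthvalue` counts from the smallest.

```python
import torch


def kth_smallest_per_row(data: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: K[i]-th smallest value of row i (matches kthvalue)."""
    sorted_vals = torch.sort(data, dim=1).values           # ascending per row
    idx = (K.long() - 1).unsqueeze(1)                       # 1-based rank -> 0-based column
    return sorted_vals.gather(1, idx).squeeze(1)            # shape [b]


def kth_largest_per_row(data: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: K[i]-th largest value of row i."""
    sorted_vals = torch.sort(data, dim=1, descending=True).values
    idx = (K.long() - 1).unsqueeze(1)
    return sorted_vals.gather(1, idx).squeeze(1)


torch.manual_seed(0)
b, c = 10, 6
data = torch.randn(b, c, requires_grad=True)
K = torch.randint(low=1, high=c + 1, size=(b,))            # ranks in [1, c]

# Reference: the per-row loop from the question.
loop = torch.stack([data[i].kthvalue(int(K[i])).values for i in range(b)])
vec = kth_smallest_per_row(data, K)
print(torch.equal(loop, vec))

# The K-th largest is the (c + 1 - K)-th smallest.
largest = kth_largest_per_row(data, K)
print(torch.equal(largest, kth_smallest_per_row(data, c + 1 - K)))

# Both sort and gather are differentiable with respect to the values,
# so each row receives gradient at exactly one position.
vec.sum().backward()
print(data.grad.sum(dim=1))
```

Sorting costs O(c log c) per row, so this trades some extra work per row for removing the Python loop. When only the K-th largest is needed and the ranks are small relative to `c`, calling `data.topk(int(K.max()), dim=1)` and applying the same `gather` may avoid a full sort; whether that is faster is worth benchmarking on the actual shapes and device.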