I have a tensor data with a shape of [b, c], and a tensor K with a shape of [b]. How can I efficiently get the K-th largest values in parallel?
The following code shows what I need:
    import torch

    b, c = 10, 6
    data = torch.randn(b, c)
    K = torch.randint(low=1, high=c, size=(b,))

    # Loop over the batch and call kthvalue once per row
    result = []
    for i in range(b):
        result.append(data[i].kthvalue(int(K[i])).values)
    result = torch.tensor(result)

    print(data)
    print(K)
    print(result)
Since kthvalue only accepts a single integer k as the ranking parameter, I do not know how to implement a parallel version of the above efficiently.
I am using the above process in the forward method of my neural network, so an efficient implementation matters to me. For now I am using the naive loop above, which may become slow when the batch size b is large.
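For reference, here is a minimal sketch of one possible vectorized approach, not necessarily the fastest: sort each row once, then pick a different column per row with gather. The helper name batched_kthvalue is hypothetical (it is not a PyTorch API), and the sketch assumes K holds 1-based ranks in the range [1, c], as in the loop above.

```python
import torch


def batched_kthvalue(data: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for each row i, return the K[i]-th smallest value
    of data[i], using 1-based ranks like torch.kthvalue."""
    # Sort every row in ascending order; shape stays [b, c].
    sorted_vals, _ = torch.sort(data, dim=1)
    # Convert 1-based ranks to 0-based column indices with shape [b, 1].
    idx = (K.long() - 1).unsqueeze(1)
    # Select one column per row, then drop the extra dimension -> shape [b].
    return sorted_vals.gather(1, idx).squeeze(1)


torch.manual_seed(0)
b, c = 10, 6
data = torch.randn(b, c)
K = torch.randint(low=1, high=c, size=(b,))

result = batched_kthvalue(data, K)

# Reference result from the per-row loop, for comparison.
expected = torch.stack([data[i].kthvalue(int(K[i])).values for i in range(b)])
print(torch.equal(result, expected))
```

Note that kthvalue returns the k-th smallest element, so this sketch matches the loop above. If the K-th largest value is what is actually wanted, passing descending=True to torch.sort should give that instead. The cost is a full sort of each row, which is usually acceptable when c is small relative to b.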