I have a tensor `data` with a shape of `[b, c]`, and a tensor `K` with a shape of `[b]`. How can I efficiently compute, in parallel, the `K[i]`-th smallest value of each row `data[i]` (which is what `kthvalue` returns)?
The following code shows what I need:
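```python
import torch

b, c = 10, 6
data = torch.randn(b, c)
K = torch.randint(low=1, high=c, size=(b,))

result = []
for i in range(b):
    # K[i]-th smallest value of row i (kthvalue's k is 1-indexed)
    result.append(data[i].kthvalue(int(K[i])).values)
# torch.stack keeps the autograd graph; torch.tensor(result) would not
result = torch.stack(result)

print(data)
print(K)
print(result)
```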
Since the two APIs `topk` and `kthvalue` only accept a single integer `k` as the ranking parameter, I do not know how to implement an efficient parallel version of the above.
P.S.
I am using the above process in the `forward` method of my neural network, so an efficient implementation is really appealing to me. Right now I am using the trivial loop above, which may become slow when the batch size `b` is large.
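When `c` is large but the requested ranks are small, a variant that may avoid fully sorting each row is to call `topk` once with `k = K.max()` and `largest=False`, then gather per row as before. This is a sketch under the same assumptions (1-indexed int64 `K` in `[1, c]`, same device as `data`); the helper name `batched_kthvalue_topk` is hypothetical, and whether it actually beats a full sort depends on the sizes and hardware, so benchmarking both is advisable.

```python
import torch


def batched_kthvalue_topk(data: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper with the same contract as a per-row kthvalue loop.

    Assumes data has shape [b, c] and K is an int64 tensor of shape [b]
    with 1-indexed values in [1, c], on the same device as data.
    """
    # Only the max(K) smallest entries of each row are needed.
    # int() copies the scalar to the host, which synchronizes on a GPU.
    k_max = int(K.max())
    # largest=False returns the k_max smallest values; sorted=True orders
    # them ascending, so position j holds the (j + 1)-th smallest value.
    smallest = torch.topk(data, k=k_max, dim=1, largest=False, sorted=True).values
    # Select position K[i] - 1 in row i, then drop the singleton column.
    return smallest.gather(1, (K - 1).unsqueeze(1)).squeeze(1)


data = torch.tensor([[0.5, -1.0, 2.0, 0.0],
                     [3.0, 1.0, 2.0, 4.0]])
K = torch.tensor([2, 1])
print(batched_kthvalue_topk(data, K).tolist())  # [0.0, 1.0]
```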