How to speed up this simple program that calls torch.topk() twice?

What I want to do is:

Given a tensor of shape [32, 100] as input, I want to get two tensors of shape [32, 50] as output: the first contains the 50 largest elements of each row and the second contains the 50 smallest.

So far, I know I can call torch.topk() twice to get the two tensors, but I assume that isn’t efficient (since topk is time-consuming), and I think there should be some way to get the second tensor directly from the result of the first topk.

In short, what is the best way to implement the above with only one call to torch.topk()?
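The two-call version described in the question is not shown, so here is a minimal sketch of it for reference. The random [32, 100] input and the helper name `two_topk` are assumptions for illustration; only `torch.topk` and its `largest` argument come from PyTorch itself.

```python
import torch

def two_topk(t: torch.Tensor, k: int = 50):
    """Hypothetical helper: the baseline that calls torch.topk() twice along dim 1."""
    # k largest elements of each row, returned in descending order
    top_big = torch.topk(t, k, dim=1, largest=True).values
    # k smallest elements of each row (a second, separate topk pass)
    top_sml = torch.topk(t, k, dim=1, largest=False).values
    return top_big, top_sml

t = torch.randn(32, 100)
top_big, top_sml = two_topk(t)
print(top_big.shape, top_sml.shape)  # torch.Size([32, 50]) torch.Size([32, 50])
```

Each call scans the whole row, which is the cost the question is trying to avoid.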

Hi Fried!

In this specific use case you want, at the end, all 100 elements split
into the 50 largest and the 50 smallest. So you will be better off calling
sort() once, rather than calling topk() one or more times:

>>> import torch
>>> print (torch.__version__)
>>> _ = torch.manual_seed (2022)
>>> t = torch.randn  (32, 100)
>>> tsort = t.sort (dim = 1)[0]
>>> topbig = tsort[:, 50:].flip (1)
>>> topsml = tsort[:, :50]
>>> topbig[0, :5]
tensor([3.0777, 2.6153, 2.5442, 2.3476, 1.8371])
>>> topsml[0, :5]
tensor([-2.2356, -2.2051, -2.1979, -2.1629, -1.8965])
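If you specifically want a single topk() call, as the question asks, one option (a sketch, not the only way) is to take the top 50 and recover the other 50 from the complement of the returned indices. The helper name `split_with_one_topk` is hypothetical. Note that it still sorts the 50 leftover elements to put them in ascending order, so it is not obviously faster than the single sort() above; timing both on your own hardware is the way to find out.

```python
import torch

def split_with_one_topk(t: torch.Tensor, k: int):
    """Hypothetical helper: the k largest elements of each row from one topk()
    call, plus the remaining t.size(1) - k elements recovered from the complement."""
    vals, idx = torch.topk(t, k, dim=1)          # k largest per row, descending
    keep = torch.ones_like(t, dtype=torch.bool)  # True = element is NOT in the top k
    # Clear the mask at the top-k positions (src must match keep's bool dtype).
    keep.scatter_(1, idx, torch.zeros_like(idx, dtype=torch.bool))
    # Boolean indexing flattens row by row; every row keeps exactly n - k
    # elements, so the result can be reshaped back to one row per input row.
    rest = t[keep].view(t.size(0), -1)
    # Sort the leftovers ascending so they line up with tsort[:, :k].
    rest = rest.sort(dim=1).values
    return vals, rest

_ = torch.manual_seed(2022)
t = torch.randn(32, 100)
topbig, topsml = split_with_one_topk(t, 50)
```

If the order of the smaller half does not matter, the final `rest.sort(...)` can be dropped.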


K. Frank

Stupid of me. That definitely works for me. Thanks a lot.