PyTorch equivalent of np.partition()

Hi

I need help writing a fast version of np.partition() using PyTorch. https://docs.scipy.org/doc/numpy-1.14.1/reference/generated/numpy.partition.html.

import numpy as np

a = np.array([3, 4, 2, 1])
# np.partition moves 3 to index 2 and 4 to index 3, the positions they would occupy if the array were fully sorted; the smaller values end up before them in no guaranteed order
np.partition(a, (2, 3))
# e.g. array([2, 1, 3, 4])
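For reference, the values np.partition places at indices 2 and 3 in the example above can be read off in PyTorch with torch.kthvalue, which returns the k-th *smallest* element with a 1-based k. This is a minimal sketch on the same 1D data (as a float tensor); it does not rearrange the tensor the way np.partition does. A fully sorted tensor is always a valid np.partition result, so torch.sort works as a correct, if O(n log n), stand-in when the whole rearranged array is needed.

```python
import torch

a = torch.tensor([3.0, 4.0, 2.0, 1.0])

# torch.kthvalue returns (values, indices) for the k-th smallest element, k is 1-based,
# so sorted positions 2 and 3 correspond to k=3 and k=4.
third_smallest = torch.kthvalue(a, 3).values   # value that belongs at index 2
fourth_smallest = torch.kthvalue(a, 4).values  # value that belongs at index 3

# A fully sorted tensor satisfies np.partition's contract for any kth,
# so torch.sort can stand in for the whole partitioned array.
partitioned = torch.sort(a).values
```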

I need to do something similar in PyTorch. Specifically, I'm calling torch.max() over a large 2D tensor (the 1D case would be equivalent) and I'd like to get the largest and second-largest values.

Is there a more efficient way than calling torch.max() once, masking out the values it found, and then calling it again on the masked tensor (i.e. something like the below)?

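import torch

array = torch.randn(1000, 500)  # hypothetical stand-in for the large 2D tensor

# 2D version of max search over dim 0: torch.max returns (values, row indices), one per column
largest, rows = torch.max(array, 0)
masked = array.clone()  # mask a copy so the original tensor is left untouched
masked[rows, torch.arange(array.size(1))] = float('-inf')

# find 2nd largest value in each column
second_largest, rows_2nd = torch.max(masked, 0)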

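Use torch.topk, which is even simpler: torch.topk(array, k=2, dim=0) returns the two largest values along dim 0 (sorted in descending order by default) together with their indices, in a single call.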
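A minimal sketch of the torch.topk approach, assuming a random 1000 x 500 tensor as hypothetical stand-in data. It takes the two largest values per column along dim 0, cross-checks them against the mask-and-call-max-again approach from the question, and shows the 1D case on the flattened tensor. No benchmark is implied; the point is that topk avoids the clone-and-mask step.

```python
import torch

torch.manual_seed(0)
array = torch.randn(1000, 500)  # hypothetical data standing in for the large 2D tensor

# Two largest values in each column (dim=0); sorted=True by default, so row 0 is the max.
top = torch.topk(array, k=2, dim=0)
largest, second_largest = top.values[0], top.values[1]  # each of shape (500,)
rows, rows_2nd = top.indices[0], top.indices[1]         # row index of each value

# Cross-check against the mask-and-call-max-again approach from the question.
first, first_rows = torch.max(array, 0)
masked = array.clone()
masked[first_rows, torch.arange(array.size(1))] = float("-inf")
second, _ = torch.max(masked, 0)

# 1D case: the two largest entries of a flat tensor and their positions.
flat_top = torch.topk(array.flatten(), k=2)
```

Passing largest=False to torch.topk gives the k smallest values instead.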


That's it, it seems to be the perfect function! Thanks so much.