Reobtain sorted tensor from argsort


I am looking for an efficient (PyTorch-friendly) way to recover the sorted version of a matrix from its argsort indices:


import torch
k = torch.tensor([[2, 9, 3],
                  [8, 0, 0]])
sorted_k, indx = k.sort()

However, running k[indx] leads to an error because the shapes don't match, instead of returning sorted_k. My best version so far is torch.cat([row[ind].unsqueeze(0) for ind, row in zip(indx, k)]).
But since it involves a Python for loop, large computations get very slow when I run it. Is there a better way?
Thanks for your help!

I’m not sure I got your point. If you want to know how to use the indexes to sort a tensor, then this snippet could help:

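import torch
k = torch.tensor([[2, 9, 3],
                  [8, 0, 0]])
sorted_k, indx = k.sort()  # sorts along the last dimension by default
# gather along dim 1 picks k[i][indx[i][j]] for every position (i, j)
sorted_k_again = k.gather(1, indx)
print(sorted_k_again.equal(sorted_k))  # prints True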

sorted_k_again is what you want, if I understood correctly.
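In case it helps to see why gather replaces the loop, here is a small sketch comparing three ways of applying the indices row by row. The loop version is my reading of the one in the question (I am assuming it was wrapped in torch.cat), and the names gathered, looped, rows and indexed are just illustrative:

```python
import torch

k = torch.tensor([[2, 9, 3],
                  [8, 0, 0]])
sorted_k, indx = k.sort()  # sorts along the last dimension by default

# gather along dim 1 computes gathered[i][j] = k[i][indx[i][j]]
gathered = k.gather(1, indx)

# Loop-based version from the question (torch.cat wrapper is an assumption),
# kept here only for comparison
looped = torch.cat([row[ind].unsqueeze(0) for ind, row in zip(indx, k)])

# Equivalent advanced indexing: a column of row numbers broadcasts against indx
rows = torch.arange(k.size(0)).unsqueeze(1)  # shape (2, 1)
indexed = k[rows, indx]

print(gathered)
# tensor([[2, 3, 9],
#         [0, 0, 8]])
```

All three give the same tensor as sorted_k, but gather and the advanced-indexing form avoid the Python-level loop over rows, which is what slows things down on large inputs.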


That’s exactly it. Thank you a lot!