Hi,
I am looking for an efficient (PyTorch-friendly) way to recover the sorted version of a matrix from its argsort indices.
Example:
import torch
k = torch.tensor([[2, 9, 3],
                  [8, 0, 0]])
sorted_k, indx = k.sort()  # sorts each row (last dimension by default)
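But running k[indx] does not return sorted_k. It indexes along the first dimension rather than within each row, so it fails here (and for a square matrix it would return a tensor of the wrong shape). My best version so far is: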
torch.cat([row[ind].unsqueeze(0) for ind, row in zip(indx, k)])
But since it involves a Python for loop, it gets very slow on large inputs. Is there a way to do better?
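One vectorized option is torch.gather: with dim=-1, each entry of indx is read as a column index within its own row, which is exactly the per-row lookup the loop performs. The sketch below assumes a 2-D tensor sorted along its last dimension; torch.take_along_dim (available in recent PyTorch releases, 1.9 and later) expresses the same operation.

```python
import torch

# Example matrix from the question.
k = torch.tensor([[2, 9, 3],
                  [8, 0, 0]])

# sort() works along the last dimension by default and returns (values, indices).
sorted_k, indx = k.sort()

# gather builds a tensor whose element (i, j) is k[i, indx[i, j]].
recovered = torch.gather(k, -1, indx)  # equivalently: k.gather(-1, indx)

# take_along_dim does the same thing, NumPy style.
recovered2 = torch.take_along_dim(k, indx, dim=-1)

# The loop-based version, kept only for comparison.
looped = torch.cat([row[ind].unsqueeze(0) for ind, row in zip(indx, k)])

print(recovered)  # tensor([[2, 3, 9], [0, 0, 8]]) (printed across two lines)
```

Both calls run as a single tensor operation (on CPU or GPU) instead of one Python iteration per row. For a sort along another dimension, pass the same dim to sort() and gather().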
Thanks for your help!
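Regarding why k[indx] on its own fails: a single index tensor selects along dim 0, so every entry of indx is treated as a row number (here the value 2 is out of range for a 2-row tensor). Pairing indx with a broadcast row index turns it into a per-row lookup. The sketch below uses plain advanced indexing and assumes a 2-D tensor sorted along its last dimension; rows is an illustrative helper name.

```python
import torch

k = torch.tensor([[2, 9, 3],
                  [8, 0, 0]])
sorted_k, indx = k.sort()

# Row numbers as a column vector of shape (2, 1); it broadcasts against indx (2, 3).
rows = torch.arange(k.size(0)).unsqueeze(1)

# Element (i, j) of the result is k[rows[i, 0], indx[i, j]] = k[i, indx[i, j]].
recovered = k[rows, indx]
```

This gives the same result as gather for the 2-D case; gather tends to generalize more directly to higher-dimensional tensors.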