How to sort matrices in parallel? [Solved]

I have a batch of diagonal matrices A of size (batch_size, 3, 3) holding the eigenvalues, and a batch of eigenvector matrices W of size (batch_size, 3, 3).
I want to sort the columns of W so that they are ordered by the absolute values of their eigenvalues, largest first.

What I do now is:

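```python
# assumed: the eigenvalues are read off the diagonal of each A[i] (the right-hand side was missing)
diag = torch.diagonal(A, dim1=-2, dim2=-1)
diag = torch.abs(diag)
sort = torch.argsort(diag, dim=1, descending=True)
for i in range(W.size(0)):
    # reorder the columns (eigenvectors) of W[i]; W[i, sort[i]] would reorder the rows instead
    W[i] = W[i][:, sort[i]]
```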

But it is SLOOOW!
In particular, the for loop takes about 97% of the time.
How can I do this sorting in parallel?

(Just FYI: I am implementing batched diagonalization myself because I don't believe one exists. My case of symmetric 3x3 matrices allows for very simple code.)

Turns out all I had to do was use torch.gather to reorder the columns of the whole batch in a single call.
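A minimal sketch of the torch.gather approach, assuming A holds the eigenvalues on its diagonal and W holds the eigenvectors as columns. The function name `sort_columns_by_eigenvalue` is hypothetical. The index passed to gather must have the same number of dimensions as W, so the per-batch column order is broadcast over the rows with `expand`:

```python
import torch


def sort_columns_by_eigenvalue(A: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """Reorder the columns of each W[b] by descending |eigenvalue| from diag(A[b]).

    Hypothetical helper: assumes A is (batch, n, n) diagonal and W is (batch, n, n)
    with one eigenvector per column.
    """
    # Eigenvalue magnitudes, shape (batch, n)
    diag = torch.diagonal(A, dim1=-2, dim2=-1).abs()
    # Column order for each batch element, shape (batch, n)
    order = torch.argsort(diag, dim=1, descending=True)
    # gather along dim=2 computes out[b, r, c] = W[b, r, index[b, r, c]],
    # so every row r uses the same column order: index[b, r, c] = order[b, c]
    index = order.unsqueeze(1).expand(-1, W.size(1), -1)
    return torch.gather(W, 2, index)
```

Since gather works along dim=2, each row of W[b] is permuted the same way, which moves whole columns at once for the entire batch. It returns a new tensor rather than updating W in place.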

From ~45 s to ~3 s for the same experiment. Noice.