I have a batch of diagonal matrices A of size (batch_size, 3, 3), and a batch of eigenvector matrices W of size (batch_size, 3, 3).

I want to sort the columns of W so that they are ordered by their corresponding eigenvalues (largest absolute value first).

What I do now is:

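```
import torch

# Absolute values of the diagonal entries (the eigenvalues), shape (batch_size, 3)
diag = torch.cat(
    (A[:, 0, 0].view(-1, 1), A[:, 1, 1].view(-1, 1), A[:, 2, 2].view(-1, 1)),
    dim=1)
diag = torch.abs(diag)
# Per-matrix ordering, largest |eigenvalue| first
sort = torch.argsort(diag, dim=1, descending=True)
for i in range(W.size(0)):
    # Note: this indexes dim 0 of W[i] (its rows); columns would be W[i][:, sort[i]]
    W[i] = W[i, sort[i]]
```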

But it is SLOOOW!

In particular the for loop takes about 97% of the time.

How can I do this sorting in parallel?

(Just FYI, I am implementing a batched diagonalization, as I don't believe one exists. My case of symmetric 3x3 matrices allows for very simple code.)
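
For reference, here is a minimal sketch of what a loop-free version might look like, assuming the goal is to permute the columns of each W[i] by descending absolute eigenvalue. The function name `sort_eigvecs_by_abs_eigval` is made up for illustration; `torch.diagonal`, `torch.argsort`, and `torch.gather` are existing PyTorch calls.

```python
import torch

def sort_eigvecs_by_abs_eigval(A, W):
    """Hypothetical helper: reorder the columns of each W[b] by descending |A[b, k, k]|."""
    # Absolute diagonal entries of every matrix in the batch, shape (batch, n)
    diag = torch.diagonal(A, dim1=-2, dim2=-1).abs()
    # Per-matrix column order, largest |eigenvalue| first, shape (batch, n)
    order = torch.argsort(diag, dim=1, descending=True)
    # Repeat the order along the row dimension so that every row of W[b]
    # is permuted with the same column indices, shape (batch, n, n)
    index = order.unsqueeze(1).expand(-1, W.size(1), -1)
    # out[b, i, j] = W[b, i, order[b, j]]
    return torch.gather(W, 2, index)

# Small usage example with random data
A = torch.diag_embed(torch.randn(4, 3))  # batch of diagonal matrices
W = torch.randn(4, 3, 3)                 # batch of eigenvector matrices
W_sorted = sort_eigvecs_by_abs_eigval(A, W)
```

Unlike the loop, this returns a new tensor instead of writing into W. If the row-wise behaviour of `W[i, sort[i]]` is what is actually wanted, advanced indexing such as `W[torch.arange(W.size(0))[:, None], order]` should reproduce it without a loop.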