I have a batch of diagonal eigenvalue matrices A of shape (batch_size, 3, 3) and a batch of eigenvector matrices W of shape (batch_size, 3, 3), where the columns of each W[i] are the eigenvectors.
I want to sort the columns of each W[i] in descending order of the absolute value of their eigenvalues.
What I do now is:
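```python
import torch

# |eigenvalues| taken from the diagonal of each A[i]: shape (batch_size, 3)
diag = torch.abs(torch.diagonal(A, dim1=-2, dim2=-1))
# Per-batch column order, largest |eigenvalue| first: shape (batch_size, 3)
sort = torch.argsort(diag, dim=1, descending=True)
for i in range(W.size(0)):
    # Permute the columns (not the rows) of the i-th eigenvector matrix
    W[i] = W[i, :, sort[i]]
```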
But it is very slow!
The for loop alone takes about 97% of the run time.
How can I do this sorting in parallel across the whole batch, without the Python loop?
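One possible way to drop the loop, sketched here under the assumption that A holds the eigenvalues on its diagonal and W stores the eigenvectors as columns, is to compute the per-batch permutation with `torch.argsort` and apply it to every matrix at once with `torch.gather` along the column dimension (`dim=2`). The function name `sort_columns_by_abs_diag` is hypothetical.

```python
import torch


def sort_columns_by_abs_diag(A: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reorder the columns of each W[b] by descending |A[b, j, j]|.

    A: (batch_size, n, n) diagonal matrices holding the eigenvalues (assumption).
    W: (batch_size, n, n) matrices whose columns are the eigenvectors (assumption).
    """
    # |eigenvalues| for every batch element: shape (batch_size, n)
    diag = torch.diagonal(A, dim1=-2, dim2=-1).abs()
    # Column permutation per batch element: shape (batch_size, n)
    order = torch.argsort(diag, dim=1, descending=True)
    # Repeat the permutation for every row: index[b, r, j] = order[b, j]
    index = order.unsqueeze(1).expand(-1, W.size(1), -1)
    # gather along dim=2 gives out[b, r, j] = W[b, r, order[b, j]],
    # i.e. the columns of each W[b] are permuted, for the whole batch at once
    return torch.gather(W, dim=2, index=index)


# Usage: with W = identity and diagonal (1, -3, 2), the column order becomes e1, e2, e0
A = torch.diag_embed(torch.tensor([[1.0, -3.0, 2.0]]))
W = torch.eye(3).unsqueeze(0)
W_sorted = sort_columns_by_abs_diag(A, W)
```

The gather index has to match W's shape, which is why the (batch_size, n) permutation is broadcast over the row dimension with `expand` (a view, no copy).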
(Just FYI: I am implementing a batched diagonalization, as I don't believe one exists. My case of symmetric 3x3 matrices allows for very simple code.)
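For comparison, and assuming PyTorch 1.8 or newer, `torch.linalg.eigh` operates on the last two dimensions of its input, so a whole batch of symmetric matrices can be diagonalized in one call. The sketch below uses random symmetric matrices as stand-in data; `eigh` returns eigenvalues in ascending order with eigenvectors as columns, so they are re-sorted by descending absolute value with the same argsort-plus-gather pattern.

```python
import torch

torch.manual_seed(0)

# Stand-in data: a batch of random symmetric 3x3 matrices, shape (batch_size, 3, 3)
batch_size = 4
M = torch.randn(batch_size, 3, 3)
S = M + M.transpose(-2, -1)

# Batched symmetric eigendecomposition:
# eigvals: (batch_size, 3), ascending; eigvecs: (batch_size, 3, 3), columns are eigenvectors
eigvals, eigvecs = torch.linalg.eigh(S)

# Reorder by descending |eigenvalue| for the whole batch, without a Python loop
order = torch.argsort(eigvals.abs(), dim=1, descending=True)
eigvals_sorted = torch.gather(eigvals, 1, order)
eigvecs_sorted = torch.gather(eigvecs, 2, order.unsqueeze(1).expand(-1, 3, -1))
```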