I want to use autograd to track gradients through a matrix exponential. I initially tried `torch.linalg.matrix_exp`, but it runs too slowly for my purposes, so instead I am exponentiating the matrix by eigendecomposition (diagonalizing it and exponentiating the eigenvalues on the diagonal).
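In formulas, assuming Ξ (`self.Xi`) is diagonalizable as Ξ = V Λ V⁻¹, the code below is meant to compute, for every time t_k in `t`:

```latex
z(t_k) = e^{\Xi t_k} z_0 = V \, e^{\Lambda t_k} \, V^{-1} z_0,
\qquad
e^{\Lambda t_k} = \operatorname{diag}\!\left(e^{\lambda_1 t_k}, \dots, e^{\lambda_n t_k}\right)
```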
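```python
# Diagonalize Xi = V diag(Lambda) V^{-1} (eigenvalues/eigenvectors are complex in general)
Lambda, V = torch.linalg.eig(self.Xi)
V = V.to(dtype=torch.cdouble)
Lambda = Lambda.to(dtype=torch.cdouble)
V_inv = torch.linalg.inv(V)
# exp(Lambda * t_k) on the diagonal, one (n, n) matrix per time point: shape (len(t), n, n)
exp_Lambda_t = torch.diag_embed(torch.exp(Lambda.unsqueeze(0) * t.unsqueeze(1)))
# Project the initial state onto the eigenbasis: V^{-1} z0, shape (n, 1), then repeat per time
z0_c = self.z0.to(dtype=torch.cdouble).unsqueeze(-1)
V_inv_z0 = V_inv @ z0_c
V_inv_z0 = V_inv_z0.unsqueeze(0).expand(len(t), -1, -1)
# z(t_k) = V exp(Lambda t_k) V^{-1} z0, shape (len(t), n) after the squeeze
z_t = torch.matmul(V, torch.matmul(exp_Lambda_t, V_inv_z0))
z_t = z_t.squeeze(-1)
# Element-wise complex modulus |z|
z_t = torch.abs(z_t)
```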
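For reference, here is a self-contained sketch of the same eigendecomposition computation next to a `torch.linalg.matrix_exp` reference. `Xi`, `z0` and `t` are hypothetical stand-ins (a damped rotation block plus one decaying mode, chosen so that exp(Ξt) z0 has a simple closed form). Two small substitutions, both assumptions on my part: broadcasting replaces `diag_embed`/`expand` (V diag(e) c equals V (e ⊙ c)), and `torch.linalg.solve` replaces forming V⁻¹ explicitly.

```python
import torch

# Hypothetical stand-ins for self.Xi, self.z0 and t (not from the question):
# a damped rotation block (eigenvalues -0.1 +/- 1j) plus a decaying mode (-0.5).
Xi = torch.tensor([[-0.1, 1.0, 0.0],
                   [-1.0, -0.1, 0.0],
                   [0.0, 0.0, -0.5]], dtype=torch.double)
z0 = torch.tensor([1.0, 0.0, 2.0], dtype=torch.double)
t = torch.linspace(0.0, 2.0, 5, dtype=torch.double)


def evolve_eig(Xi, z0, t):
    """z(t_k) = V exp(Lambda t_k) V^{-1} z0 for every t_k, via eigendecomposition."""
    Lambda, V = torch.linalg.eig(Xi)                # complex128 for a float64 input
    c = torch.linalg.solve(V, z0.to(V.dtype))       # V^{-1} z0 without forming the inverse
    exp_Lt = torch.exp(Lambda.unsqueeze(0) * t.unsqueeze(1))  # shape (len(t), n)
    # Row k is V @ (exp(Lambda t_k) * c); broadcasting replaces diag_embed/expand.
    return (exp_Lt * c) @ V.T                        # complex, shape (len(t), n)


def evolve_expm(Xi, z0, t):
    """Reference: z(t_k) = expm(Xi t_k) z0 with one batched matrix_exp call."""
    return torch.linalg.matrix_exp(Xi * t[:, None, None]) @ z0  # real, shape (len(t), n)


z_eig = evolve_eig(Xi, z0, t)
z_ref = evolve_expm(Xi, z0, t)
print(z_eig.imag.abs().max())              # only round-off should remain here
print(torch.allclose(z_eig.real, z_ref))
```

Even for this well-conditioned example the eigendecomposition result comes back with a complex dtype, so tiny imaginary parts are expected rather than a sign of a bug.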
Theoretically, the result of this computation should be real and should not depend on the phases of the chosen eigenvectors, but PyTorch throws `RuntimeError: linalg_eig_backward: The eigenvectors in the complex case are specified up to multiplication by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.`
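The phase-independence claim can be checked directly. Rescaling the eigenvectors as V' = V D with D = diag(e^{iφ₁}, …, e^{iφₙ}) gives V'⁻¹ = D⁻¹ V⁻¹, and diagonal matrices commute, so:

```latex
V' e^{\Lambda t} V'^{-1} z_0
= V D \, e^{\Lambda t} D^{-1} V^{-1} z_0
= V e^{\Lambda t} D D^{-1} V^{-1} z_0
= V e^{\Lambda t} V^{-1} z_0
```

So in exact arithmetic `z_t` is unaffected by the phase choice; the error is raised in the backward pass of `torch.linalg.eig` (`linalg_eig_backward`), not by the forward value.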
Interestingly, this error only happens sporadically.
Printing the first row of the offending matrix `z_t` shows that its entries are not exactly real: because of floating-point round-off, they have small nonzero imaginary parts (on the order of 1e-15 to 1e-14):

```
[-49.61000061-1.98751675e-15j -1.63878882-1.43775610e-14j
 -0.65327233-5.49382679e-15j -0.26159999+2.13188759e-15j
  1.57380724-7.87235174e-16j -0.06189372-4.70826694e-15j]
```

I tried to sidestep this by taking the complex modulus (`torch.abs`) of each entry, but the `RuntimeError` remained.
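A side note on the `abs` step: `torch.abs` returns the complex modulus |z|, which is never negative, so it also flips the sign of entries such as -49.61. If the intent is only to drop the round-off imaginary parts, `torch.real` keeps the sign. The sketch below uses the row printed above; it concerns the forward values only and is not, by itself, a fix for the backward error.

```python
import torch

# First row of z_t as printed in the question, as a complex128 tensor.
row = torch.tensor([
    -49.61000061 - 1.98751675e-15j, -1.63878882 - 1.43775610e-14j,
    -0.65327233 - 5.49382679e-15j, -0.26159999 + 2.13188759e-15j,
    1.57380724 - 7.87235174e-16j, -0.06189372 - 4.70826694e-15j,
], dtype=torch.cdouble)

modulus = torch.abs(row)     # |z| >= 0: negative real entries become positive
real_part = torch.real(row)  # Re(z): keeps the sign, discards the tiny imaginary part

print(modulus)
print(real_part)
```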
Is there any way to reduce the floating-point imprecision, or to sidestep the error thrown by PyTorch and force it to continue? Or is there a better matrix-exponentiation method I should use instead? I have also been seriously considering downgrading to `torch<=1.10` to avoid the error.
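On the last question: if the slowness came from calling `torch.linalg.matrix_exp` once per time point (an assumption; the question does not say how it was called), a single batched call over all times may be worth timing before downgrading. `torch.linalg.matrix_exp` accepts a batch of shape `(..., n, n)`. In this sketch `Xi`, `z0` and `t` are hypothetical stand-ins; everything stays real, so the backward pass never goes through `linalg_eig_backward`.

```python
import torch

# Hypothetical stand-ins for self.Xi, self.z0 and t from the question.
Xi = torch.tensor([[-0.1, 1.0, 0.0],
                   [-1.0, -0.1, 0.0],
                   [0.0, 0.0, -0.5]], dtype=torch.double, requires_grad=True)
z0 = torch.tensor([1.0, 0.0, 2.0], dtype=torch.double)
t = torch.linspace(0.0, 2.0, 5, dtype=torch.double)

# One batched call: exp(Xi * t_k) for every k, shape (len(t), n, n),
# applied to z0 to give z(t_k) for every k, shape (len(t), n).
z_t = torch.linalg.matrix_exp(Xi * t[:, None, None]) @ z0

# Any real-valued loss; gradients flow back to Xi without an eigendecomposition.
loss = z_t.pow(2).sum()
loss.backward()
print(Xi.grad)
```

Whether this is fast enough depends on the matrix size and number of time points, so it should be benchmarked against the eigendecomposition version on the real problem.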
Thanks in advance.