Hi! I’m wondering if someone has insight on the following issue. My actual problem is more complicated, but I’ve tried to boil it down to the essentials. Suppose that $\theta$ is an $n \times n$ matrix and I have a loss function that looks like:
$$
L(\theta) = \text{tr}\left\{ V(\theta)\, (X - \theta)(X - \theta)^T \right\} - \frac{n}{2} \log \det V(\theta)
$$
However, $V(\theta)$ is itself a function of $\theta$: it is defined implicitly as the solution of the discrete-time Lyapunov equation $V(\theta) = \theta\, V(\theta)\, \theta^T + Q$ for a fixed $Q$.
I know several methods, of varying efficiency, for finding the $V(\theta)$ that solves the discrete-time Lyapunov equation for a fixed $\theta$. The question is: how can I compute it so that gradients propagate appropriately through the solution? The only way I can see is to use properties of the vec operator and the Kronecker product to write:
$$
(I - \theta \otimes \theta)\, \text{vec}(V(\theta)) = \text{vec}(Q)
$$
So using torch.kron, torch.linalg.solve, and torch.reshape, I can express $V(\theta)$ entirely in differentiable torch operations and (presumably) get gradient propagation.
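For concreteness, here is a minimal sketch of what I mean (the function name is mine; it assumes the spectral radius of $\theta$ is below 1, so a unique solution exists):

```python
import torch

def solve_discrete_lyapunov_kron(theta: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    """Solve V = theta @ V @ theta.T + Q via the n^2 x n^2 vectorised system.

    Every operation is a differentiable torch op, so autograd should
    backpropagate through the solve. Cost is O(n^6): fine for small n only.
    """
    n = theta.shape[0]
    eye = torch.eye(n * n, dtype=theta.dtype, device=theta.device)
    # vec(theta V theta^T) = (theta ⊗ theta) vec(V); this identity also holds
    # for torch's row-major reshape/flatten, not just the column-major vec.
    A = eye - torch.kron(theta, theta)
    vec_V = torch.linalg.solve(A, Q.reshape(-1))
    return vec_V.reshape(n, n)
```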
My questions are twofold:
- Concerning the API of torch.kron: the way the documentation is phrased makes me worry that gradients only propagate through the first argument. I’m hopeful this is wrong! In other words, if I call torch.kron(theta, theta) as above, will gradients propagate correctly through both arguments?
- This method is $O(n^6)$ (a dense solve of an $n^2 \times n^2$ system) to find $V(\theta)$. Ouch. Does anyone know how to get one of the more efficient ways of solving the discrete-time Lyapunov equation to play nicely with torch autograd? Basically, an autograd-friendly version of scipy.linalg.solve_discrete_lyapunov? (I’ve sketched below the kind of thing I’m imagining.)
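To make that concrete, here is roughly the pattern I’m imagining: a custom torch.autograd.Function (the class name is mine) that calls scipy.linalg.solve_discrete_lyapunov in the forward pass and handles the backward pass by implicit differentiation, i.e. one adjoint Lyapunov solve rather than backpropagating through the solver’s internals. I haven’t verified the gradient formulas with torch.autograd.gradcheck, so please treat this as a sketch, not a working solution; it also assumes unbatched CPU tensors.

```python
import torch
from scipy.linalg import solve_discrete_lyapunov

class DiscreteLyapunov(torch.autograd.Function):
    """V = DiscreteLyapunov.apply(theta, Q) solves V = theta V theta^T + Q.

    Forward: call scipy's solver on detached numpy arrays.
    Backward: implicit differentiation -- one adjoint discrete Lyapunov solve
    instead of differentiating through the solver's internals.
    """

    @staticmethod
    def forward(ctx, theta, Q):
        theta_np = theta.detach().cpu().numpy()
        Q_np = Q.detach().cpu().numpy()
        V = torch.as_tensor(
            solve_discrete_lyapunov(theta_np, Q_np),
            dtype=theta.dtype, device=theta.device,
        )
        ctx.save_for_backward(theta, V)
        return V

    @staticmethod
    def backward(ctx, grad_V):
        theta, V = ctx.saved_tensors
        # Adjoint equation: S = theta^T S theta + grad_V.
        S = torch.as_tensor(
            solve_discrete_lyapunov(
                theta.detach().cpu().numpy().T,
                grad_V.detach().cpu().numpy(),
            ),
            dtype=theta.dtype, device=theta.device,
        )
        # Implicitly differentiating V - theta V theta^T = Q gives:
        grad_theta = S @ theta @ V.T + S.T @ theta @ V
        grad_Q = S
        return grad_theta, grad_Q
```

Usage would then be V = DiscreteLyapunov.apply(theta, Q), with the backward pass costing only one additional Lyapunov solve.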
Thanks for your time and thought!