Propagate gradients through torch.kron?

Hi! I’m wondering if someone has insight into the following issue. My problem is more complicated, but I’ve tried to boil it down to the essentials. Suppose that $\theta$ is an $n \times n$ matrix and I have a loss function that looks like:

$L(\theta) = \operatorname{tr}\{ V(\theta) (X - \theta) (X - \theta)^T \} - \frac{n}{2} \log\det V(\theta)$
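For reference, here is a minimal sketch of this loss in torch. It takes V as an already computed tensor, assumed symmetric positive definite so that the log-determinant is real. The helper name `loss` is hypothetical:

```python
import math
import torch

def loss(theta: torch.Tensor, X: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: tr(V (X - theta)(X - theta)^T) - (n/2) log det V."""
    n = theta.shape[0]
    R = X - theta                      # residual, n x n
    fit = torch.trace(V @ R @ R.T)     # trace term
    return fit - 0.5 * n * torch.logdet(V)

# Sanity check: theta = 0, X = I, V = 2I with n = 2 gives 4 - 2 log 2.
n = 2
val = loss(torch.zeros(n, n, dtype=torch.float64),
           torch.eye(n, dtype=torch.float64),
           2 * torch.eye(n, dtype=torch.float64))
print(math.isclose(val.item(), 4 - 2 * math.log(2)))  # True
```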

However, $V(\theta)$ is itself a function of $\theta$. Specifically, $V$ is defined implicitly as the solution of the discrete-time Lyapunov equation $V(\theta) = \theta V(\theta) \theta^T + Q$ for a fixed matrix $Q$.
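For intuition: when every eigenvalue of $\theta$ has modulus below 1, the solution is the convergent series $V = \sum_{k \ge 0} \theta^k Q (\theta^T)^k$. The sketch below uses the squared Smith (doubling) iteration, which sums the first $2^j$ terms of that series after $j$ steps using only matrix products. The helper name `dlyap_doubling` and the fixed iteration count are illustrative assumptions. Because it uses only torch matmuls, autograd can differentiate through the unrolled loop, at the cost of storing every intermediate.

```python
import torch

def dlyap_doubling(theta: torch.Tensor, Q: torch.Tensor, num_iter: int = 20) -> torch.Tensor:
    """Hypothetical helper: approximate V = theta V theta^T + Q by doubling.

    Assumes the spectral radius of theta is below 1. After j iterations,
    V holds the partial sum of theta^k Q (theta^T)^k for k = 0 .. 2**j - 1.
    """
    A, V = theta, Q
    for _ in range(num_iter):
        V = V + A @ V @ A.T   # add the next 2**j terms of the series
        A = A @ A             # theta^(2**j) -> theta^(2**(j+1))
    return V

# Example theta with eigenvalues 0.5 and 0.3 (upper triangular).
theta = torch.tensor([[0.5, 0.2], [0.0, 0.3]], dtype=torch.float64, requires_grad=True)
Q = torch.eye(2, dtype=torch.float64)

V = dlyap_doubling(theta, Q)
residual = V - (theta @ V @ theta.T + Q)
print(residual.abs().max().item())   # should be at round-off level

V.sum().backward()                   # gradients flow through the unrolled loop
print(theta.grad)
```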

I know several methods, of varying efficiency, for computing the $V(\theta)$ that solves the discrete-time Lyapunov equation at a fixed $\theta$. The question is how to compute it so that gradients propagate correctly through the solution. The only way I can see is to use properties of the vec operator and the Kronecker product to rewrite the equation as a linear system:
$(I_{n^2} - \theta \otimes \theta) \operatorname{vec}(V(\theta)) = \operatorname{vec}(Q)$
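The identity behind this step is $\operatorname{vec}(AXB) = (B^T \otimes A)\operatorname{vec}(X)$, so $\theta V \theta^T$ vectorizes with the factor $\theta \otimes \theta$. Below is a quick numerical check in torch, using random float64 test matrices. One detail matters for the reshape-based approach. `torch.reshape` flattens in row-major order, so it stacks rows rather than columns as the textbook vec does. For this particular product, both conventions still give the same factor $\theta \otimes \theta$.

```python
import torch

torch.manual_seed(0)
n = 3
theta = torch.randn(n, n, dtype=torch.float64)
V = torch.randn(n, n, dtype=torch.float64)

lhs = (theta @ V @ theta.T).reshape(-1)          # vec(theta V theta^T), row-major
rhs = torch.kron(theta, theta) @ V.reshape(-1)   # (theta kron theta) vec(V)
print(torch.allclose(lhs, rhs))  # True
```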

So with torch.kron, torch.linalg.solve (torch.solve in older releases), and torch.reshape, I can express $V(\theta)$ entirely in torch operations and, presumably, get gradient propagation.
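Here is a sketch of that approach, assuming a PyTorch version that provides `torch.linalg.solve`. The name `dlyap_kron` is a hypothetical helper. `torch.autograd.gradcheck` compares the autograd gradient against finite differences. Since theta feeds both arguments of `torch.kron` here, a passing check indicates that gradients are flowing through both of them.

```python
import torch

def dlyap_kron(theta: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: solve (I - theta kron theta) vec(V) = vec(Q).

    Dense n^2 x n^2 solve: O(n^6) time and O(n^4) memory. vec is row-major
    here (torch.reshape), which yields the same factor theta kron theta.
    """
    n = theta.shape[0]
    I = torch.eye(n * n, dtype=theta.dtype, device=theta.device)
    M = I - torch.kron(theta, theta)
    v = torch.linalg.solve(M, Q.reshape(-1))
    return v.reshape(n, n)

torch.manual_seed(0)
n = 3
theta = (0.3 * torch.randn(n, n, dtype=torch.float64)).requires_grad_()
Q = torch.eye(n, dtype=torch.float64)

V = dlyap_kron(theta, Q)
print(torch.allclose(V, theta @ V @ theta.T + Q))  # True

# Autograd vs finite differences; theta enters both arguments of torch.kron.
print(torch.autograd.gradcheck(lambda t: dlyap_kron(t, Q), (theta,)))  # True
```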

My questions are twofold:

  1. Concerning torch.kron: the documentation isn't explicit on this point, and the way it's phrased makes me worry that gradients only propagate through the first argument. I'm hopeful this is wrong! In other words, if both arguments to torch.kron are built from theta, as in the Kronecker product above, will gradients propagate correctly through both?
  2. This method costs O(n^6) to solve for $V(\theta)$, because it solves a dense $n^2 \times n^2$ linear system. Ouch. Does anyone know how to make one of the more efficient ways of solving the discrete-time Lyapunov equation play nicely with torch autograd? Basically, an autograd-friendly version of scipy.linalg.solve_discrete_lyapunov?
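One common route for the second question is implicit differentiation, sketched here under stated assumptions. The idea is to compute $V$ with any fast solver outside autograd, then obtain gradients from a second Lyapunov equation. If $G = \partial L / \partial V$, differentiating $V = \theta V \theta^T + Q$ shows that the adjoint $S$ solving $S = \theta^T S \theta + G$ gives $\partial L / \partial \theta = S \theta V^T + S^T \theta V$ and $\partial L / \partial Q = S$. The class below wraps `scipy.linalg.solve_discrete_lyapunov`, which solves $A X A^H - X + Q = 0$, in a custom `torch.autograd.Function`. The class name is hypothetical. The sketch runs on CPU through NumPy and does not support double backward.

```python
import scipy.linalg
import torch


class DiscreteLyapunov(torch.autograd.Function):
    """Hypothetical wrapper: V solving V = theta V theta^T + Q, adjoint-based gradients."""

    @staticmethod
    def forward(ctx, theta, Q):
        # Fast solve outside autograd; SciPy solves A X A^H - X + Q = 0.
        V_np = scipy.linalg.solve_discrete_lyapunov(
            theta.detach().cpu().numpy(), Q.detach().cpu().numpy()
        )
        V = torch.as_tensor(V_np, dtype=theta.dtype, device=theta.device)
        ctx.save_for_backward(theta, V)
        return V

    @staticmethod
    def backward(ctx, G):
        theta, V = ctx.saved_tensors
        # Adjoint equation S = theta^T S theta + G, solved with the same routine.
        S_np = scipy.linalg.solve_discrete_lyapunov(
            theta.detach().cpu().numpy().T, G.detach().cpu().numpy()
        )
        S = torch.as_tensor(S_np, dtype=G.dtype, device=G.device)
        grad_theta = None
        if ctx.needs_input_grad[0]:
            grad_theta = S @ theta @ V.T + S.T @ theta @ V
        grad_Q = S if ctx.needs_input_grad[1] else None
        return grad_theta, grad_Q


torch.manual_seed(0)
n = 3
theta = (0.3 * torch.randn(n, n, dtype=torch.float64)).requires_grad_()
Q = torch.eye(n, dtype=torch.float64, requires_grad=True)

V = DiscreteLyapunov.apply(theta, Q)
print(torch.allclose(V, theta @ V @ theta.T + Q))                   # True
print(torch.autograd.gradcheck(DiscreteLyapunov.apply, (theta, Q)))  # True
```

Each forward and backward pass then costs one Lyapunov solve. According to SciPy's documentation, its default switches from the Kronecker ("direct") method to a bilinear-transform method for larger matrices, which avoids forming the $n^2 \times n^2$ system.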

Thanks for your time and thought!