KL divergence between two Gaussians with diagonal covariance for a VAE - is there an efficient way to compute it?

I am training a VAE and want the prior distribution to have non-unit variance while keeping its covariance diagonal.

I know that there is this equation:

D_{KL}\big(\mathcal{N}(\mu_1, \Sigma_1)\,\|\,\mathcal{N}(\mu_2, \Sigma_2)\big) = \frac{1}{2}\left[\log\frac{|\Sigma_2|}{|\Sigma_1|} - d + \mathrm{tr}\!\left(\Sigma_2^{-1}\Sigma_1\right) + (\mu_2 - \mu_1)^\top \Sigma_2^{-1}(\mu_2 - \mu_1)\right]

And for diagonal matrices this simplifies to:

\frac{1}{2}\left[\sum_i \log \sigma_2^{(i)} - \sum_i \log \sigma_1^{(i)} - d + \sum_i \frac{\sigma_1^{(i)}}{\sigma_2^{(i)}} + \sum_i \frac{\big(\mu_2^{(i)} - \mu_1^{(i)}\big)^2}{\sigma_2^{(i)}}\right]

where \sigma_1^{(i)} and \sigma_2^{(i)} are the diagonal elements (variances) of \Sigma_1 and \Sigma_2 respectively.

Implementing this naively as:
KL = 0.5 * torch.sum(logVar2 - logVar + (torch.exp(logVar) + mu ** 2) / sigma2 ** 2 - 1)
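
For reference, here is the same closed-form KL written out as a standalone function (just a sketch; mu/logVar are the posterior stats from the encoder, mu2/logVar2 are the fixed prior stats, so the prior tensors can be built once and reused):

import torch

def diag_gaussian_kl(mu, logVar, mu2, logVar2):
    # KL( N(mu, diag(exp(logVar))) || N(mu2, diag(exp(logVar2))) )
    # mu, logVar: (batch, latent_dim); mu2, logVar2 can broadcast from (latent_dim,)
    var_ratio = torch.exp(logVar - logVar2)          # sigma_1^2 / sigma_2^2
    quad = (mu - mu2) ** 2 * torch.exp(-logVar2)     # (mu_1 - mu_2)^2 / sigma_2^2
    kl = 0.5 * (logVar2 - logVar + var_ratio + quad - 1.0)
    return kl.sum(dim=-1)                            # per-sample KL, summed over dims

As a sanity check, torch.distributions.kl_divergence(Normal(mu, std1), Normal(mu2, std2)) should give the same values element-wise.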

where the re-param is done like this:
mu(x) + sigma(x) * eps , eps ~ N(0,I)
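
In code that is roughly this (sketch; logVar is the encoder's log-variance output, so sigma = exp(0.5 * logVar)):

def reparameterize(mu, logVar):
    # z = mu + sigma * eps with eps ~ N(0, I)
    eps = torch.randn_like(mu)
    return mu + torch.exp(0.5 * logVar) * eps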

Is there a more efficient way to do this? Computation time has gone up significantly compared to using the regular KL loss.

Would it make sense to change the re-param to this:
mu(x) + sigma(x) * eps, eps ~ N(0, \Sigma_2), and use the KL loss that is normally used for unit variance?
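
i.e. something like this sketch, where prior_std (a name I'm making up here) holds the square roots of the diagonal of \Sigma_2:

def reparameterize_scaled(mu, logVar, prior_std):
    # eps ~ N(0, Sigma_2) for diagonal Sigma_2, i.e. eps = prior_std * N(0, I)
    eps = prior_std * torch.randn_like(mu)
    return mu + torch.exp(0.5 * logVar) * eps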

Thanks in advance.
