KL divergence for (Multivariate)Normal with diagonal covariance matrix

I use torch.distributions for variational inference and rely on torch.distributions.kl_divergence to compute analytical KL divergences. However, no KL divergence is registered for normal distributions with diagonal covariance when they are defined by wrapping Normal in Independent. The code example below shows this in detail.

The KL divergence can be computed correctly if we use the MultivariateNormal class with an explicit diagonal covariance matrix, but this is rather inefficient in both memory and compute time, since it materializes a full matrix that is known to be diagonal.

Is there another way to go here, or do I simply have to implement the analytical formula myself?

import torch

# Define diagonal normal with MultivariateNormal (to compute the KL w.r.t.)
mu = torch.tensor([0.0, 0.0])
scale1 = torch.tensor([1.0, 1.0])
cov1 = torch.diag_embed(scale1**2)  # square the standard deviations to get variances
diag_normal_mv1 = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=cov1)

# Define a diagonal normal with Independent
scale2 = torch.tensor([0.5, 1.5])
diag_normal = torch.distributions.Independent(torch.distributions.Normal(mu, scale2), 1)

# Define the equivalent diagonal normal with MultivariateNormal
cov2 = torch.diag_embed(scale2**2)
diag_normal_mv2 = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=cov2)

torch.distributions.kl_divergence(diag_normal_mv1, diag_normal_mv2)  # tensor(0.9345)
torch.distributions.kl_divergence(diag_normal_mv1, diag_normal)  # raises NotImplementedError
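
For reference, the closed-form KL between two diagonal Gaussians p = N(mu1, diag(sigma1^2)) and q = N(mu2, diag(sigma2^2)) is 0.5 * sum_i [ log(sigma2_i^2 / sigma1_i^2) + (sigma1_i^2 + (mu1_i - mu2_i)^2) / sigma2_i^2 - 1 ]. A minimal sketch of implementing it by hand (the function name kl_diag_normals is my own choice; it assumes the event dimension is the last axis):

def kl_diag_normals(mu1, sigma1, mu2, sigma2):
    # Closed-form KL(N(mu1, diag(sigma1^2)) || N(mu2, diag(sigma2^2))),
    # summed over the last (event) dimension.
    var1, var2 = sigma1.pow(2), sigma2.pow(2)
    return 0.5 * (torch.log(var2 / var1) + (var1 + (mu1 - mu2).pow(2)) / var2 - 1.0).sum(-1)

kl_diag_normals(mu, scale1, mu, scale2)  # tensor(0.9345), matching the MultivariateNormal result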

Hi, I have a problem that has been bothering me for a long time.
I want to implement the JS divergence for (Multivariate)Normal distributions with a diagonal covariance matrix.
Could you give me some advice? Thanks
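
For what it's worth, the JS divergence has no closed form even for Gaussians, because the mixture m = 0.5 * (p + q) is not Gaussian, so it is typically estimated by sampling. A minimal Monte Carlo sketch (the function name js_divergence_mc and the default sample count are arbitrary choices, not part of torch.distributions):

import math
import torch

def js_divergence_mc(p, q, n_samples=10_000):
    # JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = 0.5 * (p + q).
    def log_m(x):
        # log m(x) = logsumexp(log p(x), log q(x)) - log 2
        return torch.logsumexp(torch.stack([p.log_prob(x), q.log_prob(x)]), dim=0) - math.log(2.0)

    xp, xq = p.sample((n_samples,)), q.sample((n_samples,))
    kl_pm = (p.log_prob(xp) - log_m(xp)).mean()
    kl_qm = (q.log_prob(xq) - log_m(xq)).mean()
    return 0.5 * (kl_pm + kl_qm)

This works with any pair of distributions exposing sample and log_prob, including Independent-wrapped Normals.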

Did anyone figure out a solution to the original question here? Or is anyone willing to share an implementation of the analytical formula?
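
One workaround, sketched below under the assumption that the Independent distribution wraps a Normal, is to register the analytical formula with torch.distributions so that kl_divergence can dispatch to it (the function name _kl_mvn_diagnormal is my own choice):

import torch
from torch.distributions import Independent, MultivariateNormal, Normal
from torch.distributions.kl import register_kl

@register_kl(MultivariateNormal, Independent)
def _kl_mvn_diagnormal(p, q):
    # General Gaussian KL with a diagonal second argument:
    # 0.5 * (tr(S2^-1 S1) + (mu2 - mu1)^T S2^-1 (mu2 - mu1) - k + log det S2 - log det S1)
    if not isinstance(q.base_dist, Normal):
        raise NotImplementedError
    var2 = q.base_dist.scale.pow(2)
    diff = q.base_dist.loc - p.loc
    trace = (p.covariance_matrix.diagonal(dim1=-2, dim2=-1) / var2).sum(-1)
    quad = (diff.pow(2) / var2).sum(-1)
    k = p.event_shape[0]
    logdet2 = var2.log().sum(-1)
    logdet1 = 2 * p.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    return 0.5 * (trace + quad - k + logdet2 - logdet1)

With the distributions from the original post, torch.distributions.kl_divergence(diag_normal_mv1, diag_normal) then returns tensor(0.9345) instead of raising NotImplementedError.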

I have the same problem.