I use torch.distributions for variational inference and therefore use torch.distributions.kl_divergence to compute analytical KL divergences. However, no KL divergence is registered between a MultivariateNormal and a normal distribution with diagonal covariance that is defined by wrapping Normal with Independent. The code example below shows this in detail.
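The KL divergence is computed correctly if the diagonal normal is instead defined with the MultivariateNormal class and an explicit covariance matrix. However, this is rather inefficient in both memory and compute time when the covariance is diagonal: the dense k × k matrix is stored and Cholesky-factorized even though only its k diagonal entries carry information.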
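For comparison, kl_divergence does not raise when both distributions are defined as Independent(Normal(...), 1): PyTorch registers a KL rule for an Independent/Independent pair that sums the elementwise Normal/Normal KL over the reinterpreted dimensions (and raises NotImplementedError if the two reinterpreted_batch_ndims differ). A minimal sketch, assuming both sides can be expressed this way and reusing the values from the question's example:

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

mu = torch.tensor([0.0, 0.0])
scale1 = torch.tensor([1.0, 1.0])
scale2 = torch.tensor([0.5, 1.5])

# Both diagonal normals wrap Normal with Independent (reinterpreted_batch_ndims=1),
# so only k means and k scales are stored, never a k x k covariance matrix.
p = Independent(Normal(mu, scale1), 1)
q = Independent(Normal(mu, scale2), 1)

# Dispatches to the built-in Independent/Independent rule: the per-dimension
# Normal/Normal KL terms are summed over the event dimension.
kl = kl_divergence(p, q)
print(kl)
```

This keeps the cost linear in k, but it only applies when neither side needs a full-covariance MultivariateNormal.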
Is there any other way to go here or do I simply have to implement the analytical formula myself?
import torch

# Standard normal defined via MultivariateNormal (the distribution the KL is computed from)
mu = torch.tensor([0.0, 0.0])
scale1 = torch.tensor([1.0, 1.0])
cov1 = torch.diag_embed(scale1**2)  # covariance = squared scales on the diagonal
diag_normal_mv1 = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=cov1)

# Diagonal normal defined by wrapping Normal with Independent
scale2 = torch.tensor([0.5, 1.5])
diag_normal = torch.distributions.Independent(torch.distributions.Normal(mu, scale2), 1)

# The same diagonal normal defined via MultivariateNormal with a dense covariance matrix
cov2 = torch.diag_embed(scale2**2)
diag_normal_mv2 = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=cov2)

torch.distributions.kl_divergence(diag_normal_mv1, diag_normal_mv2)  # -> 0.9345
torch.distributions.kl_divergence(diag_normal_mv1, diag_normal)  # -> raises NotImplementedError
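If the analytical formula does have to be written by hand, one option is to plug it into kl_divergence through torch.distributions.kl.register_kl, PyTorch's mechanism for adding KL rules between distribution types. The sketch below is an assumption-laden illustration, not a built-in: the function name _kl_mvn_diag_normal is hypothetical, it only handles q = Independent(Normal(loc, scale), 1), and the reverse direction (Independent, MultivariateNormal) would need its own registration. It uses the closed form KL(p || q) = 0.5 * [tr(Σq⁻¹ Σp) + (μq − μp)ᵀ Σq⁻¹ (μq − μp) − k + log det Σq − log det Σp], where a diagonal Σq means only the diagonal of Σp is needed for the trace.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal, kl_divergence
from torch.distributions.kl import register_kl


@register_kl(MultivariateNormal, Independent)
def _kl_mvn_diag_normal(p, q):
    # Hypothetical helper: closed-form KL(p || q) for a full-covariance p and a
    # diagonal q = Independent(Normal(loc, scale), 1). Other Independent setups are rejected.
    if not isinstance(q.base_dist, Normal) or q.reinterpreted_batch_ndims != 1:
        raise NotImplementedError
    k = p.event_shape[-1]
    var_q = q.base_dist.scale.pow(2)
    var_p = torch.diagonal(p.covariance_matrix, dim1=-2, dim2=-1)
    # tr(Sigma_q^{-1} Sigma_p): only diag(Sigma_p) matters because Sigma_q is diagonal
    trace_term = (var_p / var_q).sum(-1)
    # Mahalanobis term for the mean difference under the diagonal Sigma_q
    mahalanobis = ((q.base_dist.loc - p.loc).pow(2) / var_q).sum(-1)
    logdet_q = var_q.log().sum(-1)
    # log det Sigma_p from the Cholesky factor that MultivariateNormal already holds
    logdet_p = 2 * torch.diagonal(p.scale_tril, dim1=-2, dim2=-1).log().sum(-1)
    return 0.5 * (trace_term + mahalanobis - k + logdet_q - logdet_p)


# Usage with the values from the example above
mu = torch.tensor([0.0, 0.0])
scale2 = torch.tensor([0.5, 1.5])
p = MultivariateNormal(loc=mu, covariance_matrix=torch.eye(2))
q = Independent(Normal(mu, scale2), 1)
q_mv = MultivariateNormal(loc=mu, covariance_matrix=torch.diag_embed(scale2**2))

kl_custom = kl_divergence(p, q)  # now dispatches to the registered rule
kl_reference = kl_divergence(p, q_mv)  # built-in MultivariateNormal/MultivariateNormal rule
print(kl_custom, kl_reference)
```

The two results should agree; the registered rule avoids building and factorizing a dense covariance for the diagonal side.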