I am using torch.distributions for variational inference and thus use
torch.distributions.kl_divergence to compute analytical KL divergences. However, there is no registered KL divergence for normal distributions with diagonal covariance when they are defined by wrapping Normal in Independent. The code example below shows this in detail.
The KL divergence is computed correctly if I instead instantiate the full covariance matrix with the
MultivariateNormal class, but this is rather inefficient in both memory and compute time when the covariance matrix is diagonal.
Is there any other way to go here, or do I simply have to implement the analytical formula myself?
import torch

# Define a diagonal normal with MultivariateNormal (to compute the KL wrt.)
mu = torch.Tensor([0, 0])
scale1 = torch.Tensor([1, 1])
cov1 = torch.diag_embed(scale1**2)  # covariance is the squared scale
diag_normal_mv1 = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=cov1)

# Define a diagonal normal by wrapping Normal in Independent
scale2 = torch.Tensor([0.5, 1.5])
diag_normal = torch.distributions.Independent(torch.distributions.Normal(mu, scale2), 1)

# Define the same diagonal normal with MultivariateNormal
cov2 = torch.diag_embed(scale2**2)
diag_normal_mv2 = torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=cov2)

torch.distributions.kl_divergence(diag_normal_mv1, diag_normal_mv2)
# > 0.9345
torch.distributions.kl_divergence(diag_normal_mv1, diag_normal)
# > NotImplementedError
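
For what it's worth, mixing the two wrapper types seems to be the specific problem: when both distributions are wrapped in Independent, kl_divergence appears to dispatch to a registered (Independent, Independent) rule that sums the base distributions' KL, though I am not certain in which PyTorch versions this is available. A self-contained check:

import torch

mu = torch.Tensor([0, 0])
p = torch.distributions.Independent(torch.distributions.Normal(mu, torch.Tensor([1, 1])), 1)
q = torch.distributions.Independent(torch.distributions.Normal(mu, torch.Tensor([0.5, 1.5])), 1)

# If the (Independent, Independent) rule is registered, this returns the same
# 0.9345 as the MultivariateNormal computation above; otherwise it raises
# NotImplementedError.
torch.distributions.kl_divergence(p, q)

This does not cover my actual failing case, though, where one side is a MultivariateNormal.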
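
If it does come down to implementing the analytical formula myself, my understanding is that a custom rule can be registered with torch.distributions.kl.register_kl so that kl_divergence picks it up for the (MultivariateNormal, Independent) pair. Below is a minimal sketch under that assumption; the helper name _kl_mvn_diag is mine, and the body is just the standard closed-form Gaussian KL, exploiting that Sigma_q is diagonal:

import torch
from torch.distributions import Independent, MultivariateNormal, Normal
from torch.distributions.kl import register_kl

# Hypothetical rule: KL(p || q) for p a MultivariateNormal and q a diagonal
# normal built by wrapping Normal in Independent. A sketch, not an official API.
@register_kl(MultivariateNormal, Independent)
def _kl_mvn_diag(p, q):
    if not isinstance(q.base_dist, Normal) or q.reinterpreted_batch_ndims != 1:
        raise NotImplementedError
    var_q = q.base_dist.scale.pow(2)                         # diag(Sigma_q), kept as a vector
    diag_p = p.covariance_matrix.diagonal(dim1=-2, dim2=-1)  # diag(Sigma_p)
    delta = q.base_dist.loc - p.loc
    log_det_p = 2 * p.scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    k = p.event_shape[0]
    return 0.5 * (
        (diag_p / var_q).sum(-1)          # tr(Sigma_q^{-1} Sigma_p) for diagonal Sigma_q
        + (delta.pow(2) / var_q).sum(-1)  # (mu_q - mu_p)^T Sigma_q^{-1} (mu_q - mu_p)
        - k                               # event dimension
        + var_q.log().sum(-1)             # log det Sigma_q
        - log_det_p                       # log det Sigma_p
    )

With this registered, kl_divergence(diag_normal_mv1, diag_normal) from the example above should return the same 0.9345 as the pure MultivariateNormal computation, while Sigma_q is never materialized as a full matrix.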