I am trying to compute the KL divergence between a mixture of Gaussians and a single Gaussian prior using Monte Carlo sampling. I have written the function below, which I believe works. However, it relies on a for loop (a list comprehension over the k mixture components), and since I call it inside a deep Bayesian neural network, this is not very desirable. I have tried hard, but I simply cannot get a vectorized version to work. Any help would be greatly appreciated.
```python
import torch
import torch.distributions as dist

def kl_divergence_mixture(posterior_means, posterior_stds, prior_mean, prior_std, num_samples=1):
    """
    Computes the KL divergence between a mixture of Gaussians posterior and a Gaussian prior using Monte Carlo sampling.
    Args:
        posterior_means (torch.Tensor): Means of the Gaussian components (k, D).
        posterior_stds (torch.Tensor): Standard deviations of the Gaussian components (k, D).
        prior_mean (float): Mean of the Gaussian prior.
        prior_std (float): Standard deviation of the Gaussian prior.
        num_samples (int): Number of samples per component for Monte Carlo estimation.
    Returns:
        torch.Tensor: The estimated KL divergence.
    """
    k, D = posterior_means.shape
    pi = torch.ones(k) / k
    posterior_dist = dist.Normal(posterior_means, posterior_stds)
    prior_dist = dist.Normal(prior_mean, prior_std)
    q_dists = [dist.Normal(mu_k, sigma_k) for mu_k, sigma_k in zip(posterior_means, posterior_stds)]
    samples = posterior_dist.rsample()
    log_q_w_components = torch.stack([torch.log(pi[j]) + q_dist.log_prob(samples) for j, q_dist in enumerate(q_dists)])
    log_q_w_sum = torch.logsumexp(log_q_w_components, dim=0)
    log_p_w = prior_dist.log_prob(samples)
    f_w = log_q_w_sum - log_p_w
    kl_div = f_w.mean(dim=0)
    return kl_div.sum()
```
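For reference, this is what the function above estimates, written out as I read the code (the notation is mine, not from the question). Sample i is one draw from component i, `prior_mean` and `prior_std` define p, and every coordinate d gets its own equal-weight mixture:

$$
\log q_d(w) = \log \sum_{j=1}^{k} \pi_j\, \mathcal{N}\!\left(w;\, \mu_{jd}, \sigma_{jd}^2\right), \qquad \pi_j = \tfrac{1}{k}
$$

$$
\mathrm{KL}(q \,\|\, p) \approx \sum_{d=1}^{D} \frac{1}{k} \sum_{i=1}^{k} \Big[ \log q_d(w_{id}) - \log p(w_{id}) \Big], \qquad w_{id} \sim \mathcal{N}\!\left(\mu_{id}, \sigma_{id}^2\right)
$$

The `logsumexp` over the stacked components is $\log q_d$, `.mean(dim=0)` is the $\frac{1}{k}\sum_i$, and `.sum()` is the sum over d. Averaging one draw per component with weight $\frac{1}{k}$ is an unbiased estimate because the mixture weights are all $\frac{1}{k}$. Since `log_prob` is not summed over D before the `logsumexp`, this treats q as a product of D independent one-dimensional mixtures. If a single mixture over D-dimensional vectors is intended, the component log-densities would need a `.sum(-1)` before the `logsumexp`.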
The crux is that you need to evaluate the `.log_prob()` of every posterior component on every sample, and your list comprehension over `q_dists` does exactly that, one component at a time. You can avoid the loop by duplicating the values in `samples` with `.expand()`, so that a single call to `posterior_dist.log_prob()` evaluates every component on every sample via broadcasting.
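A minimal shape check of this idea, with arbitrary toy sizes (k = 4, D = 3), a seed chosen only for reproducibility, and standard deviations shifted by 0.1 (an assumption, to keep them strictly positive). The looped stack has layout [component j, sample i, d], while the broadcast call has layout [sample i, component j, d], so the two agree after a `.permute(1, 0, 2)`. The last comparison also checks that `.unsqueeze(1)` alone gives the same result, because `log_prob` broadcasts its argument against the distribution's (k, D) batch shape:

```python
import torch
import torch.distributions as dist

# Toy sizes chosen arbitrarily for illustration
k, D = 4, 3
torch.manual_seed(0)
means = torch.randn(k, D)
stds = torch.rand(k, D) + 0.1  # shifted so every scale is strictly positive

posterior_dist = dist.Normal(means, stds)  # batch_shape (k, D)
samples = posterior_dist.rsample()          # (k, D): one draw per component

# Loop version: looped[j, i, d] = log N(samples[i, d]; means[j, d], stds[j, d])
looped = torch.stack([dist.Normal(m, s).log_prob(samples) for m, s in zip(means, stds)])

# Broadcast version: value (k, k, D) against batch_shape (k, D);
# broadcast[i, j, d] = log N(samples[i, d]; means[j, d], stds[j, d])
broadcast = posterior_dist.log_prob(samples.unsqueeze(1).expand(k, k, D))

# Without .expand(): (k, 1, D) is broadcast to (k, k, D) implicitly
no_expand = posterior_dist.log_prob(samples.unsqueeze(1))

print(looped.shape, broadcast.shape)  # both (k, k, D)
print(torch.allclose(looped.permute(1, 0, 2), broadcast))
print(torch.allclose(no_expand, broadcast))
```

The `.expand()` only creates a view, so it costs no extra memory either way. Dropping it is a matter of taste.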
Here is a script with a loop-free version (together with your original
version):
```python
import torch
print (torch.__version__)
_ = torch.manual_seed (2024)
import torch.distributions as dist

# original version with loops (list comprehensions)
def kl_divergence_mixture(posterior_means, posterior_stds, prior_mean, prior_std, num_samples=1):
    """
    Computes the KL divergence between a mixture of Gaussians posterior and a Gaussian prior using Monte Carlo sampling.
    Args:
        posterior_means (torch.Tensor): Means of the Gaussian components (k, D).
        posterior_stds (torch.Tensor): Standard deviations of the Gaussian components (k, D).
        prior_mean (float): Mean of the Gaussian prior.
        prior_std (float): Standard deviation of the Gaussian prior.
        num_samples (int): Number of samples per component for Monte Carlo estimation.
    Returns:
        torch.Tensor: The estimated KL divergence.
    """
    k, D = posterior_means.shape
    pi = torch.ones(k) / k
    posterior_dist = dist.Normal(posterior_means, posterior_stds)
    prior_dist = dist.Normal(prior_mean, prior_std)
    q_dists = [dist.Normal(mu_k, sigma_k) for mu_k, sigma_k in zip(posterior_means, posterior_stds)]
    samples = posterior_dist.rsample()
    log_q_w_components = torch.stack([torch.log(pi[j]) + q_dist.log_prob(samples) for j, q_dist in enumerate(q_dists)])
    log_q_w_sum = torch.logsumexp(log_q_w_components, dim=0)
    log_p_w = prior_dist.log_prob(samples)
    f_w = log_q_w_sum - log_p_w
    kl_div = f_w.mean(dim=0)
    return kl_div.sum()

# loop-free version
def kl_divergence_mixtureB (posterior_means, posterior_stds, prior_mean, prior_std):
    k, D = posterior_means.shape
    log_pi = torch.tensor ([1.0 / k]).log()
    posterior_dist = dist.Normal (posterior_means, posterior_stds)
    prior_dist = dist.Normal (prior_mean, prior_std)
    samples = posterior_dist.rsample()
    # expanded[i, j] == samples[i]; broadcasting against the (k, D) batch shape
    # evaluates component j at sample i, so this log_q_w_components differs
    # from the original log_q_w_components by a .permute (1, 0, 2)
    log_q_w_components = log_pi + posterior_dist.log_prob (samples.unsqueeze (1).expand (k, k, D))
    # dim = 1 (instead of the original dim = 0) compensates for the difference in log_q_w_components
    log_q_w_sum = torch.logsumexp (log_q_w_components, dim = 1)
    log_p_w = prior_dist.log_prob (samples)
    f_w = log_q_w_sum - log_p_w
    kl_div = f_w.mean (dim = 0)
    return kl_div.sum()

k = 10
D = 5
means = torch.randn (k, D)
stds = torch.rand (k, D)
mean = -0.7
std = 1.1

_ = torch.manual_seed (2024)   # ensure same values of posterior_dist.rsample() in both versions
kl_div = kl_divergence_mixture (means, stds, mean, std)
_ = torch.manual_seed (2024)   # ensure same values of posterior_dist.rsample() in both versions
kl_divB = kl_divergence_mixtureB (means, stds, mean, std)

print ('kl_div:', kl_div, ' kl_divB:', kl_divB)
print ('torch.equal (kl_div, kl_divB):', torch.equal (kl_div, kl_divB))
# bitwise equality can depend on floating-point reduction order; allclose is the more robust check
print ('torch.allclose (kl_div, kl_divB):', torch.allclose (kl_div, kl_divB))
```
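The loop-free version drops the unused `num_samples` argument. If you want more than one draw per component, one possible extension is to give `rsample` a leading sample shape and broadcast over it. This is a sketch: reading `num_samples` as "draws per component" comes from your docstring, and `kl_divergence_mixture_mc` is a hypothetical name. It keeps the same per-coordinate mixture model as the versions above:

```python
import math
import torch
import torch.distributions as dist

def kl_divergence_mixture_mc(posterior_means, posterior_stds, prior_mean, prior_std, num_samples=1):
    """Loop-free Monte Carlo estimate of KL(q || p) with num_samples draws per component.

    Same per-coordinate, equal-weight mixture as kl_divergence_mixtureB; the
    extra draws live in a leading sample dimension S = num_samples.
    """
    k, D = posterior_means.shape
    log_pi = -math.log(k)  # log of the equal mixture weight 1/k
    posterior_dist = dist.Normal(posterior_means, posterior_stds)  # batch_shape (k, D)
    prior_dist = dist.Normal(prior_mean, prior_std)
    samples = posterior_dist.rsample((num_samples,))  # (S, k, D)
    # (S, k, 1, D) broadcast against (k, D) -> (S, k, k, D);
    # entry [s, i, j, d] = log N(samples[s, i, d]; mu[j, d], sigma[j, d])
    log_q_w_components = log_pi + posterior_dist.log_prob(samples.unsqueeze(-2))
    log_q_w_sum = torch.logsumexp(log_q_w_components, dim=-2)  # (S, k, D)
    f_w = log_q_w_sum - prior_dist.log_prob(samples)            # (S, k, D)
    # average over draws and components, then sum over coordinates
    return f_w.mean(dim=(0, 1)).sum()
```

With `num_samples=1` this reduces to the same estimator as `kl_divergence_mixtureB`. The trade-off is memory: the intermediate `log_q_w_components` has shape (S, k, k, D), so for large k or S it may be worth processing the draws in chunks.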