Compute KL divergence between mixture of Gaussians and single Gaussian using vectorization

Hi,

I am trying to compute the KL divergence between a mixture of Gaussians posterior and a single Gaussian prior using Monte Carlo sampling. I have written the function below, which works - however, it relies on for loops, and since I use it in a deep Bayesian neural network, this is not desirable. I have tried extensively, but I simply cannot get it to work using vectorization. Any help would be greatly appreciated.

import torch
import torch.distributions as dist

def kl_divergence_mixture(posterior_means, posterior_stds, prior_mean, prior_std, num_samples=1):
    """
    Computes the KL divergence between a mixture of Gaussians posterior and a Gaussian prior using Monte Carlo sampling.
    
    Args:
        posterior_means (torch.Tensor): Means of the Gaussian components (k, D).
        posterior_stds (torch.Tensor): Standard deviations of the Gaussian components (k, D).
        prior_mean (float): Mean of the Gaussian prior.
        prior_std (float): Standard deviation of the Gaussian prior.
        num_samples (int): Number of samples per component for Monte Carlo estimation.
        
    Returns:
        torch.Tensor: The estimated KL divergence.
    """
    k, D = posterior_means.shape

    pi = torch.ones(k) / k

    posterior_dist = dist.Normal(posterior_means, posterior_stds)
    prior_dist = dist.Normal(prior_mean, prior_std)
    q_dists = [dist.Normal(mu_k, sigma_k) for mu_k, sigma_k in zip(posterior_means, posterior_stds)]

    samples = posterior_dist.rsample()

    log_q_w_components = torch.stack([torch.log(pi[i]) + q_dist.log_prob(samples) for i, q_dist in enumerate(q_dists)])
    log_q_w_sum = torch.logsumexp(log_q_w_components, dim=0)
    log_p_w = prior_dist.log_prob(samples)

    f_w = log_q_w_sum - log_p_w

    kl_div = f_w.mean(dim=0)

    return kl_div.sum()

Hi Jonathan!

Your issue is that you need to apply the .log_prob() of each of your
component distributions to all of your samples, and your loops over
q_dists do exactly that. However, you can avoid the loops if you
duplicate the values in samples using .expand() and let
posterior_dist.log_prob() broadcast itself over those duplicated
samples.
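
To see the shapes involved, here is a minimal check (toy sizes, and
the names expanded / k / D are mine, just to illustrate how
log_prob() broadcasts over the duplicated samples):

import torch
import torch.distributions as dist

k, D = 3, 2
posterior_dist = dist.Normal (torch.randn (k, D), torch.rand (k, D) + 0.1)   # batch shape (k, D)
samples = posterior_dist.rsample()                    # shape (k, D)
expanded = samples.unsqueeze (1).expand (k, k, D)     # shape (k, k, D)
# entry [i, j, d] is the log-density of sample i under component j, dimension d
print (posterior_dist.log_prob (expanded).shape)      # torch.Size([3, 3, 2])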

Here is a script with a loop-free version (together with your original
version):

import torch
print (torch.__version__)

_ = torch.manual_seed (2024)

import torch.distributions as dist

# original version with loops (list comprehensions)
def kl_divergence_mixture(posterior_means, posterior_stds, prior_mean, prior_std, num_samples=1):
    """
    Computes the KL divergence between a mixture of Gaussians posterior and a Gaussian prior using Monte Carlo sampling.
    
    Args:
        posterior_means (torch.Tensor): Means of the Gaussian components (k, D).
        posterior_stds (torch.Tensor): Standard deviations of the Gaussian components (k, D).
        prior_mean (float): Mean of the Gaussian prior.
        prior_std (float): Standard deviation of the Gaussian prior.
        num_samples (int): Number of samples per component for Monte Carlo estimation.
        
    Returns:
        torch.Tensor: The estimated KL divergence.
    """
    k, D = posterior_means.shape
    
    pi = torch.ones(k) / k
    
    posterior_dist = dist.Normal(posterior_means, posterior_stds)
    prior_dist = dist.Normal(prior_mean, prior_std)
    q_dists = [dist.Normal(mu_k, sigma_k) for mu_k, sigma_k in zip(posterior_means, posterior_stds)]
    
    samples = posterior_dist.rsample()
    
    log_q_w_components = torch.stack([torch.log(pi[i]) + q_dist.log_prob(samples) for i, q_dist in enumerate(q_dists)])
    log_q_w_sum = torch.logsumexp(log_q_w_components, dim=0)
    log_p_w = prior_dist.log_prob(samples)
    
    f_w = log_q_w_sum - log_p_w
    
    kl_div = f_w.mean(dim=0)
    
    return kl_div.sum()

# loop-free version
def kl_divergence_mixtureB (posterior_means, posterior_stds, prior_mean, prior_std):
    k, D = posterior_means.shape
    log_pi = torch.tensor ([1.0 / k]).log()   # log of the uniform mixture weight
    
    posterior_dist = dist.Normal (posterior_means, posterior_stds)
    prior_dist = dist.Normal (prior_mean, prior_std)
    samples = posterior_dist.rsample()   # (k, D): one sample from each component
    
    # this log_q_w_components differs from the original log_q_w_components by a .permute (1, 0, 2)
    log_q_w_components = log_pi + posterior_dist.log_prob (samples.unsqueeze (1).expand (k, k, D))

    # dim = 1 (instead of the original dim = 0) compensates for difference in log_q_w_components
    log_q_w_sum = torch.logsumexp (log_q_w_components, dim = 1)
    
    log_p_w = prior_dist.log_prob (samples)
    f_w = log_q_w_sum - log_p_w
    kl_div = f_w.mean (dim = 0)
    return kl_div.sum()


k = 10
D = 5

means = torch.randn (k, D)
stds = torch.rand (k, D)

mean = -0.7
std = 1.1

_ = torch.manual_seed (2024)   # ensure same values of posterior_dist.rsample() in both versions
kl_div  = kl_divergence_mixture  (means, stds, mean, std)

_ = torch.manual_seed (2024)   # ensure same values of posterior_dist.rsample() in both versions
kl_divB = kl_divergence_mixtureB (means, stds, mean, std)

print ('kl_div:', kl_div, '  kl_divB:', kl_divB)
print ('torch.equal (kl_div, kl_divB):', torch.equal (kl_div, kl_divB))

And here is its output:

2.3.1
kl_div: tensor(4.0859)   kl_divB: tensor(4.0859)
torch.equal (kl_div, kl_divB): True
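
As an aside, torch.distributions.MixtureSameFamily can compute the
same per-dimension mixture log-probability for you. Here is a sketch
(reusing means, stds, k, and D from the script above, and assuming
uniform weights; note that MixtureSameFamily does not support
rsample(), so you would still draw your reparameterized samples from
posterior_dist):

mix = dist.Categorical (logits = torch.zeros (D, k))   # uniform weights, batch shape (D,)
comp = dist.Normal (means.t(), stds.t())               # batch shape (D, k)
mixture = dist.MixtureSameFamily (mix, comp)           # per-dimension mixture, batch shape (D,)

samples = dist.Normal (means, stds).rsample()          # (k, D)
log_q_w_sum = mixture.log_prob (samples)               # (k, D), matches the manual logsumexp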

(Note that your num_samples argument is unused in your function, so
I didn’t include it in the loop-free version.)
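
If you do want multiple samples per component, you could extend the
loop-free version along these lines (a sketch - kl_divergence_mixtureC
is my name, and I am assuming num_samples means independent draws from
each component, averaged together):

def kl_divergence_mixtureC (posterior_means, posterior_stds, prior_mean, prior_std, num_samples = 1):
    k, D = posterior_means.shape
    log_pi = torch.tensor ([1.0 / k]).log()

    posterior_dist = dist.Normal (posterior_means, posterior_stds)
    prior_dist = dist.Normal (prior_mean, prior_std)
    samples = posterior_dist.rsample ((num_samples,))   # (num_samples, k, D)

    # [s, i, j, d]: log-density of sample s from component i under component j, dimension d
    log_q_w_components = log_pi + posterior_dist.log_prob (samples.unsqueeze (2).expand (num_samples, k, k, D))
    log_q_w_sum = torch.logsumexp (log_q_w_components, dim = 2)

    log_p_w = prior_dist.log_prob (samples)
    f_w = log_q_w_sum - log_p_w
    kl_div = f_w.mean (dim = (0, 1))   # average over all num_samples * k samples
    return kl_div.sum()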

Best.

K. Frank

Hi K. Frank,

Thanks a bunch for your assistance - that seems to do the trick.