Use KL divergence as loss between two multivariate Gaussians


I want to use KL divergence as a loss function between two multivariate Gaussians. Is the following the right way to do it?

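import torch

B, D = 4, 3  # example batch size and dimension (not given in the original)
mu1 = torch.rand((B, D), requires_grad=True)
std1 = torch.rand((B, D), requires_grad=True)
p = torch.distributions.Normal(mu1, std1)
mu2 = torch.rand((B, D))
std2 = torch.rand((B, D))
q = torch.distributions.Normal(mu2, std2)

# kl_divergence(p, q) has shape (B, D): one KL value per dimension
loss = torch.distributions.kl_divergence(p, q).mean()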

My understanding is that torch.distributions.kl_divergence computes KL(p || q) as in the derivations in section 9 of this document.
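For reference, for two univariate Gaussians the closed form is KL(N(m1, s1^2) || N(m2, s2^2)) = log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 s2^2) - 1/2, applied per dimension when both distributions are diagonal. A minimal sketch that checks this against torch.distributions.kl_divergence; the helper name gaussian_kl and the example sizes are my own, not from the thread:

```python
import torch

def gaussian_kl(mu_p, std_p, mu_q, std_q):
    # Elementwise KL(N(mu_p, std_p^2) || N(mu_q, std_q^2)) for univariate Gaussians
    var_p = std_p ** 2
    var_q = std_q ** 2
    return (torch.log(std_q / std_p)
            + (var_p + (mu_p - mu_q) ** 2) / (2 * var_q)
            - 0.5)

torch.manual_seed(0)
B, D = 4, 3
mu1 = torch.rand(B, D)
std1 = torch.rand(B, D) + 0.1  # shifted away from 0: the scale must be positive
mu2 = torch.rand(B, D)
std2 = torch.rand(B, D) + 0.1

p = torch.distributions.Normal(mu1, std1)
q = torch.distributions.Normal(mu2, std2)

manual = gaussian_kl(mu1, std1, mu2, std2)
builtin = torch.distributions.kl_divergence(p, q)
print(manual.shape)  # torch.Size([4, 3]): one value per dimension
print(torch.allclose(manual, builtin, rtol=1e-4, atol=1e-5))

# Summing over the last dimension gives the KL of the diagonal multivariate
# Gaussians, one scalar per sample
kl_per_sample = builtin.sum(-1)
```

Note that calling .mean() on the (B, D) result also divides by D, so it is the per-dimension average rather than the multivariate KL; summing over the last dimension first is the usual choice if a per-sample KL is wanted.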


Any update on this question?


You are right. When you are using distributions from the torch.distributions package, you are doing fine by using torch.distributions.kl_divergence. But if you want to compute the KL divergence from two tensors obtained elsewhere, you can use the following approach:

@Rojin I have posted this on your thread actually.

This computes the KL divergence between the outputs of two arbitrary layers.
Just be aware that the input a must contain log-probabilities and the target b must contain probabilities.
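The code from this post did not survive, so here is a minimal sketch of what the description suggests, assuming the two "layers" output logits over the same set of classes. logits_a and logits_b are hypothetical names; the input/target convention is the one used by torch.nn.functional.kl_div:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical outputs of two layers: batch of 8, 5 classes each
logits_a = torch.randn(8, 5, requires_grad=True)
logits_b = torch.randn(8, 5)

log_q = F.log_softmax(logits_a, dim=-1)  # input: log-probabilities
p = F.softmax(logits_b, dim=-1)          # target: probabilities

# F.kl_div(input, target) sums target * (log(target) - input), i.e. KL(p || q);
# reduction="batchmean" divides that sum by the batch size
loss = F.kl_div(log_q, p, reduction="batchmean")
loss.backward()

# Cross-check with torch.distributions.Categorical
ref = torch.distributions.kl_divergence(
    torch.distributions.Categorical(probs=p),
    torch.distributions.Categorical(logits=logits_a.detach()),
).mean()
print(loss.item(), ref.item())
```

reduction="batchmean" is used rather than the default "mean", because "mean" averages over every element (batch and classes) and so does not give the per-sample KL.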

By the way, PyTorch uses this approach:

Good luck

I tested kl_divergence as follows. The tensors' shapes are [batch_size, n], where n is the dimension of my Gaussian distribution (3 here):

import torch

mu1 = torch.tensor([[1., 2., 3.],
                    [2., 3., 4.]])
var_1 = torch.tensor([[1., 1., 1.],
                      [4., 9., 16.]])

mu2 = torch.tensor([[1., 3., 4.],
                    [2., 3., 4.]])
var_2 = torch.tensor([[1., 4., 9.],
                      [4., 9., 16.]])

# Note: Normal's second argument is the standard deviation (scale), not the
# variance, so var_1 and var_2 are interpreted as standard deviations here
p = torch.distributions.Normal(mu1, var_1)
q = torch.distributions.Normal(mu2, var_2)
kl_loss = torch.distributions.kl_divergence(p, q)


The output is:

tensor([[0.6542, 0.9488, 1.7096],
        [0.0000, 0.0000, 0.0000]])

Why are there 3 numbers for each instance? As I understand it, the case above contains two pairs (p, q), and each pair should have a scalar KL loss.

Any ideas? Thank you

Ooh, I got it!

In my case above, I assume the distribution is a multivariate Gaussian, so I should use torch.distributions.MultivariateNormal(mu1, var_1) instead, which gives the result I expected. :sweat_smile:
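One detail worth checking: the second positional argument of MultivariateNormal is covariance_matrix, which must have shape [..., n, n], so a [batch_size, n] tensor of variances has to be turned into diagonal matrices first (torch.diag_embed does this). A minimal sketch, assuming var_1 and var_2 really are variances as their names suggest (unlike the Normal example above, which passed them as the scale):

```python
import torch
from torch.distributions import MultivariateNormal, Normal, kl_divergence

mu1 = torch.tensor([[1., 2., 3.], [2., 3., 4.]])
var_1 = torch.tensor([[1., 1., 1.], [4., 9., 16.]])
mu2 = torch.tensor([[1., 3., 4.], [2., 3., 4.]])
var_2 = torch.tensor([[1., 4., 9.], [4., 9., 16.]])

# Build [batch_size, n, n] diagonal covariance matrices from [batch_size, n] variances
p = MultivariateNormal(mu1, covariance_matrix=torch.diag_embed(var_1))
q = MultivariateNormal(mu2, covariance_matrix=torch.diag_embed(var_2))
kl_mvn = kl_divergence(p, q)  # shape [batch_size]: one scalar per (p, q) pair

# With a diagonal covariance the dimensions are independent, so this should equal
# the per-dimension univariate KL summed over the last axis (Normal takes std = sqrt(var))
kl_sum = kl_divergence(Normal(mu1, var_1.sqrt()), Normal(mu2, var_2.sqrt())).sum(-1)
print(kl_mvn, kl_sum)
```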


Can I ask why you set requires_grad=True for mu1 and std1 but not for mu2 and std2?
Is there something we should watch out for?
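The thread doesn't say why, but a likely reason is that q is a fixed target (for example a prior) and only the parameters of p are meant to be learned; tensors without requires_grad simply receive no gradient. One thing to watch out for: if std1 itself is optimized, a gradient step can push it to zero or below, which Normal rejects. Optimizing a log-standard-deviation instead is a common workaround. A sketch under those assumptions (log_std1, the SGD optimizer and the learning rate are my own choices, not from the thread):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
B, D = 4, 3
# Learnable parameters of p; optimizing log_std1 keeps the scale positive
mu1 = torch.rand(B, D, requires_grad=True)
log_std1 = torch.zeros(B, D, requires_grad=True)
# Fixed target q: no requires_grad, so autograd computes no gradients for it
mu2 = torch.rand(B, D)
std2 = torch.rand(B, D) + 0.5

opt = torch.optim.SGD([mu1, log_std1], lr=0.1)
losses = []
for _ in range(100):
    opt.zero_grad()
    p = Normal(mu1, log_std1.exp())
    q = Normal(mu2, std2)
    loss = kl_divergence(p, q).sum(-1).mean()  # per-sample KL, averaged over the batch
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
print(mu1.grad is not None, mu2.grad)  # mu2.grad stays None
```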

I ran into a problem where MultivariateNormal does not accept a covariance of shape [batch, n], because it requires batch and n to be equal. Did you manage to solve this?
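MultivariateNormal expects a square [..., n, n] covariance matrix, which is why a [batch, n] tensor only "works" when batch equals n. One way around it, sketched here with made-up tensors, is to skip the covariance matrix entirely: wrap a batch of univariate Normals in torch.distributions.Independent, which reinterprets the last dimension as the event dimension. The result is a diagonal multivariate Gaussian with batch shape [batch_size], so kl_divergence returns one scalar per pair, and batch_size and n can differ:

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence

torch.manual_seed(0)
batch_size, n = 4, 3  # deliberately unequal
mu1 = torch.randn(batch_size, n)
std1 = torch.rand(batch_size, n) + 0.1
mu2 = torch.randn(batch_size, n)
std2 = torch.rand(batch_size, n) + 0.1

# Independent(..., 1): the last dim (n) becomes the event dim,
# leaving batch_shape [batch_size]
p = Independent(Normal(mu1, std1), 1)
q = Independent(Normal(mu2, std2), 1)
kl = kl_divergence(p, q)
print(p.batch_shape, p.event_shape, kl.shape)
```

The result matches summing the per-dimension Normal KLs over the last axis; Independent just makes that reduction part of the distribution's shape.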

Is the KL calculation correct?
Shouldn't log(a/b) = log(a) - log(b), not log(a) - b?
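Assuming the question refers to the F.kl_div-style formula discussed earlier in the thread: torch.nn.functional.kl_div expects its input to already be log-probabilities, so its pointwise term target * (log(target) - input) is p * (log p - log q) = p * log(p / q), and no log is missing. A small check with made-up probability vectors p and q:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = torch.softmax(torch.randn(5), dim=0)  # target: probabilities
q = torch.softmax(torch.randn(5), dim=0)  # model distribution: probabilities

log_q = q.log()  # F.kl_div expects log-probabilities as its input
# Pointwise term computed by F.kl_div: target * (log(target) - input)
pointwise = F.kl_div(log_q, p, reduction="none")
explicit = p * torch.log(p / q)  # log(p/q) = log(p) - log(q)
print(torch.allclose(pointwise, explicit, atol=1e-6))
print(pointwise.sum())  # KL(p || q)
```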