# Use KL divergence as loss between two multivariate Gaussians

Hi,

I want to use KL divergence as a loss function between two multivariate Gaussians. Is the following the right way to do it?

mu1 = torch.rand((B, D), requires_grad=True)
std1 = torch.rand((B, D), requires_grad=True)
p = torch.distributions.Normal(mu1, std1)
mu2 = torch.rand((B, D))
std2 = torch.rand((B, D))
q = torch.distributions.Normal(mu2, std2)

loss = torch.distributions.kl_divergence(p, q).mean()
loss.backward()

My understanding is that `torch.distributions.kl_divergence` computes KL(p, q) as derived in section 9 of this document.
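For reference, a minimal sketch that compares `kl_divergence` on two `Normal` distributions with the standard closed-form KL between univariate Gaussians. The sizes `B` and `D` and the random parameters are made up for illustration. Note that the result has shape `[B, D]`: one KL value per dimension, not one per sample.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
B, D = 4, 3  # hypothetical batch size and dimension

mu1, mu2 = torch.randn(B, D), torch.randn(B, D)
# Keep the standard deviations strictly positive.
std1, std2 = torch.rand(B, D) + 0.1, torch.rand(B, D) + 0.1

kl_lib = kl_divergence(Normal(mu1, std1), Normal(mu2, std2))

# Closed form for univariate Gaussians, applied elementwise:
# log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 1/2
kl_manual = (torch.log(std2 / std1)
             + (std1 ** 2 + (mu1 - mu2) ** 2) / (2 * std2 ** 2)
             - 0.5)

print(kl_lib.shape)
print(torch.allclose(kl_lib, kl_manual, atol=1e-5))
```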


Any update on this question?

Hi,

You are right. When you are using distributions from the `torch.distributions` package, `torch.distributions.kl_divergence` is the right tool. But if you want the KL divergence between two tensors obtained elsewhere, you can use `torch.nn.functional.kl_div(a, b)`, which gives the KL divergence between the outputs of two arbitrary layers.

Just be aware that the input `a` must contain log-probabilities and the target `b` must contain probabilities.

https://pytorch.org/docs/stable/nn.functional.html?highlight=kl_div#kl-div
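A minimal sketch of that call, assuming the two tensors are raw layer outputs (the names `logits_a` and `logits_b` and their shapes are hypothetical). Each row is turned into a discrete probability vector first, so this is a KL between categorical distributions, not between Gaussian parameters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits_a = torch.randn(4, 5)  # hypothetical output of one layer
logits_b = torch.randn(4, 5)  # hypothetical output of another layer

log_a = F.log_softmax(logits_a, dim=-1)  # input: log-probabilities
prob_b = F.softmax(logits_b, dim=-1)     # target: probabilities

# "batchmean" sums over all elements and divides by the batch size,
# giving the average of the per-row KL(b || a).
kl = F.kl_div(log_a, prob_b, reduction="batchmean")
print(kl)
```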

By the way, PyTorch uses this approach:

https://pytorch.org/docs/stable/distributions.html?highlight=kl_div#torch.distributions.kl.kl_divergence

Good luck
Nik

Hi,
I tested `kl_divergence` as follows. The tensors have shape `[batch_size, n]`, where `n` is the dimension of my Gaussian distribution (3 here):

```python
import torch

mu1 = torch.tensor([[1., 2., 3.],
                    [2., 3., 4.]])
var_1 = torch.tensor([[1., 1., 1.],
                      [4., 9., 16.]])

mu2 = torch.tensor([[1., 3., 4.],
                    [2., 3., 4.]])
var_2 = torch.tensor([[1., 4., 9.],
                      [4., 9., 16.]])

# Note: the second argument of Normal is the standard deviation (scale),
# so var_1 and var_2 are interpreted as standard deviations here.
p = torch.distributions.Normal(mu1, var_1)
q = torch.distributions.Normal(mu2, var_2)
kl_loss = torch.distributions.kl_divergence(p, q)  # elementwise, shape [2, 3]

print(kl_loss)
```

The output is:

```
tensor([[0.6542, 0.9488, 1.7096],
        [0.0000, 0.0000, 0.0000]])
```

Why are there 3 numbers for each instance? As I understand it, the case above contains two pairs `(p, q)`, and each pair should have a single scalar KL loss.

Any ideas? Thank you

Ooh, I got it!

In the case above I was assuming a multivariate Gaussian, so I should use `torch.distributions.MultivariateNormal(mu1, var_1)` instead, which gives what I expected.
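A note on the 3 numbers per instance: `Normal` with `[batch_size, n]` parameters is a batch of independent univariate Gaussians, so the KL comes back elementwise. If the covariance is diagonal, summing over the last dimension gives one scalar per pair, and wrapping the distribution in `Independent` does that for you. A minimal sketch, reusing the numbers from the post and treating the second tensors as standard deviations (that naming is my assumption):

```python
import torch
from torch.distributions import Normal, Independent, kl_divergence

mu1 = torch.tensor([[1., 2., 3.], [2., 3., 4.]])
std1 = torch.tensor([[1., 1., 1.], [4., 9., 16.]])
mu2 = torch.tensor([[1., 3., 4.], [2., 3., 4.]])
std2 = torch.tensor([[1., 4., 9.], [4., 9., 16.]])

# Elementwise KL: one value per (sample, dimension), shape [2, 3].
per_dim = kl_divergence(Normal(mu1, std1), Normal(mu2, std2))

# Treat the last dimension as a single event: one KL per sample, shape [2].
p = Independent(Normal(mu1, std1), 1)
q = Independent(Normal(mu2, std2), 1)
per_pair = kl_divergence(p, q)

print(per_dim.shape, per_pair.shape)
print(torch.allclose(per_pair, per_dim.sum(-1)))
```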


Hi,
Can I ask why you set `requires_grad=True` for `mu1` and `std1` but not for `mu2` and `std2`?
Is there something we should watch out for?
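One plausible reading, assuming `p` is the distribution being learned and `q` is a fixed target: only tensors with `requires_grad=True` receive gradients from `loss.backward()`, so the parameters of `q` are treated as constants. A minimal sketch with made-up sizes and values:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
B, D = 4, 3  # hypothetical batch size and dimension

# Parameters of p are trainable.
mu1 = torch.randn(B, D, requires_grad=True)
std1 = (torch.rand(B, D) + 0.5).requires_grad_()

# Parameters of q are fixed (no gradient tracking).
mu2 = torch.randn(B, D)
std2 = torch.rand(B, D) + 0.5

loss = kl_divergence(Normal(mu1, std1), Normal(mu2, std2)).mean()
loss.backward()

print(mu1.grad.shape, std1.grad.shape)  # gradients exist for p
print(mu2.grad, std2.grad)              # no gradients for q
```

If both distributions come from networks you want to train, both sets of parameters would need to track gradients.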

Hi,
I ran into a problem where `MultivariateNormal` does not accept a `[batch, n]` shape, because `batch` and `n` have to be equal. Did you manage to solve this?
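This error is probably because the second argument of `MultivariateNormal` is a covariance matrix: for a mean of shape `[batch, n]` it expects shape `[batch, n, n]`, so a `[batch, n]` tensor is read as one matrix that has to be square. A minimal sketch assuming the `[batch, n]` tensor holds per-dimension variances (a diagonal covariance), built with `torch.diag_embed`:

```python
import torch
from torch.distributions import MultivariateNormal, kl_divergence

mu1 = torch.tensor([[1., 2., 3.], [2., 3., 4.]])     # [batch, n]
var1 = torch.tensor([[1., 1., 1.], [4., 9., 16.]])   # [batch, n] variances
mu2 = torch.tensor([[1., 3., 4.], [2., 3., 4.]])
var2 = torch.tensor([[1., 4., 9.], [4., 9., 16.]])

# Put each row of variances on the diagonal: [batch, n] -> [batch, n, n].
cov1 = torch.diag_embed(var1)
cov2 = torch.diag_embed(var2)

p = MultivariateNormal(mu1, covariance_matrix=cov1)
q = MultivariateNormal(mu2, covariance_matrix=cov2)

kl = kl_divergence(p, q)  # one scalar per pair, shape [batch]
print(kl.shape)
```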

Is the KL calculation correct?
log(a/b) = log(a) - b???
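Assuming this question is about `torch.nn.functional.kl_div`: the pointwise term is `target * (log(target) - input)`, and `input` is expected to already be a log-probability. So the subtracted term is `log(q)` rather than `q`, which matches `log(p / q) = log(p) - log(q)`. A small check with made-up probability vectors:

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.2, 0.3, 0.5])  # target probabilities
q = torch.tensor([0.1, 0.6, 0.3])  # probabilities of the other distribution

# The input must be log(q); the target stays as plain probabilities.
pointwise = F.kl_div(q.log(), p, reduction="none")

# Same quantity written out by hand: p * log(p / q).
manual = p * (p.log() - q.log())

print(torch.allclose(pointwise, manual))
print(pointwise.sum())  # KL(p || q)
```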