Hi,

I want to use KL divergence as a loss function between two multivariate Gaussians. Is the following the right way to do it?

import torch

B, D = 4, 3  # example batch size and dimensionality (assumed values)

mu1 = torch.rand((B, D), requires_grad=True)

std1 = torch.rand((B, D), requires_grad=True)

p = torch.distributions.Normal(mu1, std1)

mu2 = torch.rand((B, D))

std2 = torch.rand((B, D))

q = torch.distributions.Normal(mu2, std2)

loss = torch.distributions.kl_divergence(p, q).mean()

loss.backward()

…

My understanding is that torch.distributions.kl_divergence computes KL(p, q) in closed form, as in the derivations in section 9 of this document.
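
For reference, here is a minimal sketch of how the shapes work out, assuming both Gaussians have diagonal covariance (which is what a Normal with (B, D) parameters represents). With plain Normal, kl_divergence returns one value per dimension, shape (B, D). Wrapping each distribution in torch.distributions.Independent treats the last dimension as a single D-dimensional event, so the result has shape (B,) and equals the sum over D. The sizes B and D, the seed, and the 0.1 offset that keeps the standard deviations away from zero are illustrative assumptions, not part of the original snippet.

```python
import torch
from torch.distributions import Normal, Independent, kl_divergence

torch.manual_seed(0)
B, D = 4, 3  # assumed batch size and dimensionality

mu1 = torch.rand(B, D, requires_grad=True)
# Offset keeps the std strictly positive (assumption for this sketch).
std1 = (torch.rand(B, D) + 0.1).requires_grad_()
mu2 = torch.rand(B, D)
std2 = torch.rand(B, D) + 0.1

# Element-wise KL between univariate Normals: shape (B, D).
kl_elem = kl_divergence(Normal(mu1, std1), Normal(mu2, std2))

# Treat the last dimension as one D-dimensional diagonal Gaussian: shape (B,).
p = Independent(Normal(mu1, std1), 1)
q = Independent(Normal(mu2, std2), 1)
kl_joint = kl_divergence(p, q)

# Closed-form KL(p || q) per dimension for univariate Gaussians.
kl_closed = (
    torch.log(std2 / std1)
    + (std1 ** 2 + (mu1 - mu2) ** 2) / (2 * std2 ** 2)
    - 0.5
)

loss = kl_joint.mean()  # average over the batch only
loss.backward()
```

Under these assumptions, kl_joint matches kl_elem.sum(-1), so calling .mean() directly on the (B, D) result gives the same loss divided by D. That rescales the gradients by a constant but does not change the minimizer.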