Hi,
I want to use the KL divergence as a loss function between two multivariate Gaussians. Is the following the right way to do it?
import torch

# B: batch size, D: number of dimensions
mu1 = torch.rand((B, D), requires_grad=True)
std1 = torch.rand((B, D), requires_grad=True)
p = torch.distributions.Normal(mu1, std1)

mu2 = torch.rand((B, D))
std2 = torch.rand((B, D))
q = torch.distributions.Normal(mu2, std2)

loss = torch.distributions.kl_divergence(p, q).mean()
loss.backward()
…
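In case it's relevant, here is what I assume is going on shape-wise (my assumption from reading the docs, not verified): Normal with (B, D) parameters is a batch of independent univariate Gaussians, so kl_divergence returns a (B, D) tensor of per-dimension KLs. For a diagonal-covariance multivariate Gaussian I would expect to either sum over the last dimension or wrap the distributions in Independent, e.g.:

import torch
from torch.distributions import Normal, Independent, kl_divergence

B, D = 4, 3  # example sizes
mu1, std1 = torch.rand(B, D), torch.rand(B, D) + 0.1  # +0.1 keeps std away from 0
mu2, std2 = torch.rand(B, D), torch.rand(B, D) + 0.1

# Element-wise KL: shape (B, D), one value per dimension
kl_elem = kl_divergence(Normal(mu1, std1), Normal(mu2, std2))

# Reinterpret the last dim as the event dim -> diagonal multivariate Gaussian, KL shape (B,)
p = Independent(Normal(mu1, std1), 1)
q = Independent(Normal(mu2, std2), 1)
kl_mv = kl_divergence(p, q)

print(torch.allclose(kl_mv, kl_elem.sum(-1)))  # True

Is summing over D (or using Independent) the right way to get a single KL per sample?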
My understanding is that torch.distributions.kl_divergence computes KL(p, q) as in the derivations in Section 9 of this document.
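The formula I have in mind for each dimension is KL(p || q) = log(std2/std1) + (std1^2 + (mu1 - mu2)^2) / (2 * std2^2) - 1/2, and this is the quick sanity check I would run against it (my own check, not taken from the PyTorch docs):

import torch
from torch.distributions import Normal, kl_divergence

mu1, std1 = torch.tensor(0.3), torch.tensor(0.8)
mu2, std2 = torch.tensor(-0.1), torch.tensor(1.2)

kl_lib = kl_divergence(Normal(mu1, std1), Normal(mu2, std2))

# Closed-form KL between two univariate Gaussians p = N(mu1, std1^2), q = N(mu2, std2^2)
kl_manual = torch.log(std2 / std1) + (std1**2 + (mu1 - mu2)**2) / (2 * std2**2) - 0.5

print(torch.allclose(kl_lib, kl_manual))  # True

Does kl_divergence follow that convention, i.e. KL(p, q) with p as the first argument?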