I recently discovered the `torch.distributions` package in PyTorch. I'd like to estimate a KL divergence, which I think I can do with a Monte Carlo estimate like this:
```python
# KL div
q = torch.distributions.normal.Normal(z_mu, std)
p = torch.distributions.normal.Normal(loc=torch.zeros_like(z_mu), scale=torch.ones_like(std))
qz = q.log_prob(z)
pz = p.log_prob(z)
inside = qz - pz
kl_loss = torch.mean(inside)
```
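(For context, `z` here is a single sample from `q` per input, so the estimate is noisy. A multi-sample version of the same estimator would look roughly like this sketch, where the shapes and `n_samples` are made up:)

```python
import torch

# made-up stand-ins for the encoder outputs
z_mu = torch.randn(8, 4)
std = torch.rand(8, 4) + 0.1

q = torch.distributions.Normal(z_mu, std)
p = torch.distributions.Normal(torch.zeros_like(z_mu), torch.ones_like(std))

n_samples = 1000                     # arbitrary
z = q.rsample((n_samples,))          # shape: (n_samples, 8, 4)

# average the log-ratio over the MC samples as well
kl_loss = (q.log_prob(z) - p.log_prob(z)).mean()
```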
But the results from this estimator are much worse than what I get by computing the KL analytically:
```python
# kl_loss = 0.5 * torch.sum(-torch.log(z_var) + z_mu**2 - 1.0 - z_var)
kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)
```
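For reference, the closed form I believe the active line implements (reading `z_var` as the log-variance, i.e. $\sigma^2 = \exp(\texttt{z\_var})$, since it gets exponentiated) is:

$$\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\big\|\,\mathcal{N}(0, 1)\big) = \frac{1}{2}\sum_i \left(\sigma_i^2 + \mu_i^2 - 1 - \log \sigma_i^2\right)$$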
Can anyone spot what I’m doing wrong in the MC version?
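In case it helps to reproduce, here's a minimal self-contained comparison (the shapes are made up). `torch.distributions.kl_divergence` has a registered closed form for a pair of Normals, so I'm treating it as the ground truth:

```python
import torch

torch.manual_seed(0)

# made-up stand-ins for the encoder outputs
z_mu = torch.randn(8, 4)
std = torch.rand(8, 4) + 0.1

q = torch.distributions.Normal(z_mu, std)
p = torch.distributions.Normal(torch.zeros_like(z_mu), torch.ones_like(std))

# single-sample MC estimate, as in my code above
z = q.rsample()
kl_mc = (q.log_prob(z) - p.log_prob(z)).mean()

# closed form from torch.distributions, reduced the same way
kl_exact = torch.distributions.kl_divergence(q, p).mean()

print(kl_mc.item(), kl_exact.item())
```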