I recently discovered the torch.distributions package in PyTorch. I'd like to estimate the KL divergence between my approximate posterior and a standard-normal prior by Monte Carlo sampling, which I think I can do like this:
# Monte Carlo estimate of KL(q || p)
q = torch.distributions.normal.Normal(z_mu, std)  # approximate posterior
p = torch.distributions.normal.Normal(loc=torch.zeros_like(z_mu), scale=torch.ones_like(std))  # N(0, 1) prior
z = q.rsample()  # reparameterized sample from q
qz = q.log_prob(z)
pz = p.log_prob(z)
kl_loss = torch.mean(qz - pz)
But the results are really bad compared with computing it analytically:
# kl_loss = 0.5 * torch.sum(-torch.log(z_var) + z_mu**2 - 1.0 + z_var)  # if z_var held the raw variance
kl_loss = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)  # z_var is the log-variance
Can anyone spot what I'm doing wrong in the Monte Carlo version?
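For concreteness, here is a self-contained repro of both versions; the shapes and the assumption that z_var is the log-variance (so std = exp(0.5 * z_var)) are mine:

```python
import torch

torch.manual_seed(0)

# Toy posterior parameters -- shapes and the log-variance
# parametrization (std = exp(0.5 * z_var)) are assumptions.
z_mu = torch.randn(128, 8)
z_var = torch.randn(128, 8)   # log-variance
std = torch.exp(0.5 * z_var)

q = torch.distributions.Normal(z_mu, std)  # approximate posterior
p = torch.distributions.Normal(torch.zeros_like(z_mu), torch.ones_like(std))  # N(0, 1) prior

# Monte Carlo version: one reparameterized sample per element
z = q.rsample()
kl_mc = torch.mean(q.log_prob(z) - p.log_prob(z))

# Analytic version, with z_var as the log-variance
kl_analytic = 0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1.0 - z_var)

# PyTorch's built-in closed form, as a sanity check
kl_builtin = torch.distributions.kl_divergence(q, p).sum()

print(kl_mc.item(), kl_analytic.item(), kl_builtin.item())
```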