# KL Divergence implementation

Hello everyone,

Why do these two implementations of the KL divergence give me different results? Can anybody find the error? The difference is about 5%.

Note: I've commented out the `.sum()` in Version 1 because otherwise the syntax highlighting in the forum text editor changes. Can anybody suggest the reason?

Version 1:

```
# Methods of the model class; assumes `import torch` and `import torch.distributions as td`.
def q_z_dist(self, mu, logvar):
    # Approximate posterior q = N(mu, diag(exp(logvar)))
    var = torch.exp(logvar)
    std = torch.sqrt(var)  # unused
    cov = torch.diag_embed(var)
    return td.MultivariateNormal(mu, cov)

def p_z_dist(self, mu, logvar):
    # Prior p = N(0, I), shaped like the posterior parameters
    mu_prior = torch.zeros_like(mu)
    var_prior = torch.ones_like(logvar)
    cov_prior = torch.diag_embed(var_prior)
    return td.MultivariateNormal(mu_prior, cov_prior)

def KL_divergence(self, mu, logvar):
    p_dist = self.p_z_dist(mu, logvar)
    q_dist = self.q_z_dist(mu, logvar)

    KL = td.kl_divergence(p_dist, q_dist)
    KL_batch = KL #.sum()
    return KL_batch
```

Version 2:

```
def KL_divergence(self, mu, logvar):
    # Closed-form KL per sample, summed over the latent dimension (dim=1)
    KL = -0.5 * (1 + logvar - mu**2 - torch.exp(logvar)).sum(dim=1)
    # Sum over the batch
    KL_batch = KL.sum()
    return KL_batch
```

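So you have the prior p = N(0, I) and q = N(mu, diag(var)).
In Version 1 you call `td.kl_divergence(p_dist, q_dist)`, which computes D_KL(p || q), whereas the closed-form expression in Version 2 is the more common D_KL(q || p). KL divergence is not symmetric, so these will differ. Swapping the arguments to `td.kl_divergence(q_dist, p_dist)` should make Version 1 match Version 2.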
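
To check this numerically, here is a minimal sketch. The batch size, latent dimension, seed, and variable names are made up for illustration and are not from the original code. It compares both argument orders of `td.kl_divergence` against the closed form from Version 2, and against the corresponding closed form for D_KL(p || q).

```python
import torch
import torch.distributions as td

torch.manual_seed(0)

# Hypothetical encoder outputs: batch of 4 samples, 3 latent dimensions.
mu = torch.randn(4, 3)
logvar = torch.randn(4, 3)
var = torch.exp(logvar)

# q = N(mu, diag(var)) and the prior p = N(0, I), as in Version 1.
q = td.MultivariateNormal(mu, torch.diag_embed(var))
p = td.MultivariateNormal(torch.zeros_like(mu),
                          torch.diag_embed(torch.ones_like(logvar)))

kl_qp = td.kl_divergence(q, p)  # D_KL(q || p), one value per sample
kl_pq = td.kl_divergence(p, q)  # D_KL(p || q), the order used in Version 1

# Closed form for D_KL(q || p), i.e. the expression from Version 2.
kl_qp_closed = -0.5 * (1 + logvar - mu**2 - var).sum(dim=1)

# Closed form for D_KL(p || q) with p = N(0, I).
kl_pq_closed = 0.5 * ((1 + mu**2) / var - 1 + logvar).sum(dim=1)

# Each call should agree with its own closed form, but not with the other.
print(torch.allclose(kl_qp, kl_qp_closed, rtol=1e-4, atol=1e-4))
print(torch.allclose(kl_pq, kl_pq_closed, rtol=1e-4, atol=1e-4))
print(torch.allclose(kl_pq, kl_qp_closed, rtol=1e-4, atol=1e-4))
```

If the first two checks pass and the third fails, that matches the explanation above: the two versions compute different quantities and only the argument order needs to change.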

Best regards

Thomas
