KL Divergence implementation

Hello everyone,

Why do these two implementations of the KL divergence give me different results? Can anybody find the error? The difference is about 5%.

Note: I've commented out the '.sum()' because if I don't, the code colours in the forum text editor change. Can anybody suggest the reason?

Version 1:

    def q_z_dist(self, mu, logvar):
        var = torch.exp(logvar)
        std = torch.sqrt(var)
        cov = torch.diag_embed(var)
        return td.MultivariateNormal(mu, cov)
    
    def p_z_dist(self, mu, logvar):
        mu_prior = torch.zeros_like(mu)
        var_prior = torch.ones_like(logvar)
        cov_prior = torch.diag_embed(var_prior)
        return td.MultivariateNormal(mu_prior, cov_prior)
    
    def KL_divergence(self, mu, logvar):
        p_dist = self.p_z_dist(mu, logvar)
        q_dist = self.q_z_dist(mu, logvar)

        KL = td.kl_divergence(p_dist, q_dist)
        KL_batch = KL #.sum()
        return KL_batch

Version 2:

    def KL_divergence(self, mu, logvar):
        KL = -0.5 * (1 + logvar - mu**2 - torch.exp(logvar)).sum(dim=1)
        KL_batch = KL.sum()
        return KL_batch

So you have the prior p = N(0, 1) and q = N(mu, diag(var)).
It seems that you compute D_KL(p || q) in the first version and the more common D_KL(q || p) in the second. The KL divergence is not symmetric, so these will differ.
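
Here is a minimal standalone sketch (not your class, with made-up shapes and random mu/logvar) that checks this: passing the distributions to td.kl_divergence in the order (q, p) reproduces your closed-form Version 2, while the order (p, q) used in Version 1 gives a different number.

    import torch
    import torch.distributions as td

    torch.manual_seed(0)
    mu = torch.randn(4, 8)      # hypothetical batch of 4 latent means, dim 8
    logvar = torch.randn(4, 8)  # hypothetical log-variances

    # q = N(mu, diag(var)), p = N(0, I), as in your code
    q = td.MultivariateNormal(mu, torch.diag_embed(torch.exp(logvar)))
    p = td.MultivariateNormal(torch.zeros_like(mu),
                              torch.diag_embed(torch.ones_like(logvar)))

    # D_KL(q || p) from torch.distributions vs. the closed form of Version 2
    kl_qp = td.kl_divergence(q, p).sum()
    kl_closed = (-0.5 * (1 + logvar - mu**2 - torch.exp(logvar)).sum(dim=1)).sum()

    print(torch.allclose(kl_qp, kl_closed))  # True: Version 2 is D_KL(q || p)
    print(td.kl_divergence(p, q).sum())      # D_KL(p || q): a different value

So swapping the arguments in td.kl_divergence in Version 1 should make the two versions agree.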

Best regards

Thomas
