How to compute KL divergence for other distributions?

Hi,

I would like to compute the KL divergence between 2 RelaxedOneHotCategorical distributions. I have the feeling I’m doing something wrong, as the KL divergence I get is way off (a large negative number).

import torch
from torch.distributions import RelaxedOneHotCategorical


# p: fixed target distribution (temperature 2.2, explicit probs)
p_m = RelaxedOneHotCategorical(torch.tensor([2.2]), probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
# q: distribution whose logits would come from a network (torch.rand is just a placeholder here)
batch_param_obtained_from_a_nn = torch.rand(2, 4)
q_m = RelaxedOneHotCategorical(torch.tensor([5.4]), logits=batch_param_obtained_from_a_nn)

z = q_m.rsample()

# attempted single-sample estimate of KL(q || p)
kl = - torch.mean(q_m.log_prob(z).exp() * (q_m.log_prob(z) - p_m.log_prob(z)))

>>> z
tensor([[0.2671, 0.2973, 0.2144, 0.2212],
        [0.2431, 0.2550, 0.3064, 0.1954]])

>>> kl
tensor(-766.7020)

Did I miss something trivial?

Thank you in advance!

Does anyone have a clue about this?

It seems like your formula is correct, but there is a thread that addresses your issue:

It could be that you get wrong results because RelaxedOneHotCategorical is numerically unstable. The linked thread also provides a solution for that.
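For illustration, here is a rough sketch of that idea: work with ExpRelaxedOneHotCategorical (the log-space distribution that RelaxedOneHotCategorical wraps internally) and form a single-sample Monte Carlo estimate KL(q || p) ≈ E_q[log q(Y) - log p(Y)]. This is only a sketch with your parameters plugged in, not a verified fix:

import torch
from torch.distributions.relaxed_categorical import ExpRelaxedOneHotCategorical

# same parameters as in your snippet, but the distributions live in log-space
p_log = ExpRelaxedOneHotCategorical(torch.tensor([2.2]), probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
q_log = ExpRelaxedOneHotCategorical(torch.tensor([5.4]), logits=torch.rand(2, 4))

# rsample() returns the log of a point on the simplex
log_z = q_log.rsample()

# single-sample Monte Carlo estimate of KL(q || p), averaged over the batch of q
kl_estimate = torch.mean(q_log.log_prob(log_z) - p_log.log_prob(log_z))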

Thank you for your answer. I see from the source code of RelaxedOneHotCategorical that it uses ExpRelaxedOneHotCategorical internally, so shouldn’t it already be protected against numerical instability?

I’m reading the paper (appendix C), and I’m not entirely comfortable with it. From what I understood, the best thing would be to use the objective written in log-space (the ExpConcrete form).

So we should add log p_theta(x | Z)?
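For reference, my rough reading of appendix C (paraphrased, dropping the temperature and parameter subscripts) is that the whole relaxed ELBO is evaluated with the ExpConcrete log-densities:

E_{Y ~ kappa(y | x)} [ log p_theta(x | exp(Y)) + log rho(Y) - log kappa(Y | x) ]

where kappa is the ExpConcrete density of the approximate posterior and rho the ExpConcrete density of the prior. So, as far as I can tell, the reconstruction term log p_theta(x | exp(Y)) does appear, because the paper relaxes the full ELBO rather than the KL term on its own.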

I’ve seen that PyTorch uses ExpConcrete to avoid underflows, as proposed by Maddison et al. (https://arxiv.org/pdf/1611.00712.pdf). But I have the feeling it’s not enough.

Would it also be correct to do it this way?

import torch
from torch.distributions import RelaxedOneHotCategorical

batch_param_obtained_from_a_nn = torch.rand(2, 4)
q_m = RelaxedOneHotCategorical(torch.tensor([5.4]), logits=batch_param_obtained_from_a_nn)
z = q_m.rsample()

# ...

# target probabilities for the comparison
p_m = torch.tensor([0.1, 0.2, 0.3, 0.4])

# compare the relaxed sample z directly against p_m; 1e-20 keeps the logs finite
kl = - torch.mean(torch.sum(torch.log(z + 1e-20) * (torch.log(z + 1e-20) - torch.log(p_m + 1e-20)), dim=1))

>>> z
tensor([[0.3518, 0.2110, 0.2363, 0.2010],
        [0.2213, 0.2337, 0.2818, 0.2632]])

>>> kl
tensor(0.3673)