Hi,
I would like to compute the KL divergence between two RelaxedOneHotCategorical distributions. I have the feeling I'm doing something wrong, as the result comes out as a huge negative number, which should be impossible for a KL divergence.
import torch
from torch.distributions import RelaxedOneHotCategorical

# fixed prior: temperature 2.2, probabilities over 4 categories
p_m = RelaxedOneHotCategorical(torch.tensor([2.2]),
                               probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))

# approximate posterior: temperature 5.4, logits that would come from a network
# (random values stand in for the network output here)
batch_param_obtained_from_a_nn = torch.rand(2, 4)
q_m = RelaxedOneHotCategorical(torch.tensor([5.4]),
                               logits=batch_param_obtained_from_a_nn)

# my attempt at estimating KL(q || p) from a single sample of q
z = q_m.rsample()
kl = -torch.mean(q_m.log_prob(z).exp() * (q_m.log_prob(z) - p_m.log_prob(z)))
z
tensor([[0.2671, 0.2973, 0.2144, 0.2212],
        [0.2431, 0.2550, 0.3064, 0.1954]])
kl
tensor(-766.7020)
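For reference, here is what I believe a plain Monte Carlo estimate of KL(q ‖ p) would look like: since the samples z already come from q, averaging log q(z) − log p(z) should take the expectation under q, with no extra q(z) weighting factor or minus sign. (Same setup as above; the random logits and the seed are just for a reproducible stand-in.)

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)  # reproducibility only

p_m = RelaxedOneHotCategorical(torch.tensor([2.2]),
                               probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))
q_m = RelaxedOneHotCategorical(torch.tensor([5.4]),
                               logits=torch.rand(2, 4))  # stand-in for NN output

# z ~ q, so the sample average of log q(z) - log p(z) estimates E_q[log q - log p]
z = q_m.rsample((100,))  # 100 samples per batch element, shape (100, 2, 4)
kl_mc = torch.mean(q_m.log_prob(z) - p_m.log_prob(z))
print(kl_mc)
```

If this reasoning is right, the estimate should be non-negative up to sampling noise, rather than something like the −766 above.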
Did I miss something trivial?
Thank you in advance!