For the Categorical distributions below, which are very close to each other, distributions.kl.kl_divergence()
returns negative values.
Interestingly, the two ways of initialisation, i.e. via probs or via logits, give different results.
#!/usr/bin/env python3
import torch
from torch.distributions.utils import probs_to_logits
print('torch.__version__', torch.__version__)
# p distrib
prob_p = torch.tensor([0.2739, 0.2857, 0.2089, 0.2316])
p_initprob = torch.distributions.Categorical(probs=prob_p)
p_initlogit = torch.distributions.Categorical(logits=probs_to_logits(prob_p))
assert torch.allclose(p_initprob.probs, p_initlogit.probs)
assert torch.allclose(p_initprob.logits, p_initlogit.logits)
# q distrib
prob_q = torch.tensor([0.2739, 0.2857, 0.2088, 0.2315])
q_initprob = torch.distributions.Categorical(probs=prob_q)
q_initlogit = torch.distributions.Categorical(logits=probs_to_logits(prob_q))
assert torch.allclose(q_initprob.probs, q_initlogit.probs)
assert torch.allclose(q_initprob.logits, q_initlogit.logits)
# KLDiv: p || q
print('KLDiv: p || q')
print('kldiv_p_q_initprob', torch.distributions.kl.kl_divergence(p_initprob, q_initprob))
print('kldiv_p_q_initlogit', torch.distributions.kl.kl_divergence(p_initlogit, q_initlogit))
# KLDiv: q || p
print('KLDiv: q || p')
print('kldiv_q_p_initprob', torch.distributions.kl.kl_divergence(q_initprob, p_initprob))
print('kldiv_q_p_initlogit', torch.distributions.kl.kl_divergence(q_initlogit, p_initlogit))
Output:
torch.__version__ 1.1.0
KLDiv: p || q
kldiv_p_q_initprob tensor(-1.6520e-08)
kldiv_p_q_initlogit tensor(1.0270e-07)
KLDiv: q || p
kldiv_q_p_initprob tensor(6.7579e-08)
kldiv_q_p_initlogit tensor(-5.1627e-08)
Is this a known issue due to numerical instability?
Any clue on how to address it?
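As a sanity check on the numerical-instability hypothesis, I repeated the computation in float64 and clamped the result at zero. This is only a workaround sketch, not a fix to kl_divergence itself:

```python
import torch

# Same tensors as above, but in double precision to shrink round-off error.
prob_p = torch.tensor([0.2739, 0.2857, 0.2089, 0.2316], dtype=torch.float64)
prob_q = torch.tensor([0.2739, 0.2857, 0.2088, 0.2315], dtype=torch.float64)

p = torch.distributions.Categorical(probs=prob_p)
q = torch.distributions.Categorical(probs=prob_q)

kl_pq = torch.distributions.kl.kl_divergence(p, q)
# Guard downstream code against tiny negative round-off values.
kl_pq = torch.clamp(kl_pq, min=0.0)
print('kl_pq (float64, clamped)', kl_pq)
```

In float64 the value comes out small but non-negative for these inputs, which supports the round-off explanation, though the clamp is still needed in general if downstream code assumes KL >= 0.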
Thank you.
Related issues: