I thought torch.nn.functional.kl_div should compute the KL divergence as defined in the Wikipedia article "Kullback–Leibler divergence" (the same as scipy.stats.entropy and tf.keras.losses.KLDivergence), but I cannot get the same results from a simple example. Does anyone know why?
from scipy.stats import entropy
<tf.Tensor: shape=(), dtype=float32, numpy=0.08717668>
Those two give the same KL divergence value. But when I use torch.nn.functional.kl_div, the result is different.
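The original inputs are not shown, so here is a minimal sketch of the mismatch using hypothetical distributions (`p` and `q` below are the example from the Wikipedia article, with `q` uniform; they are an assumption, not the poster's data). Passing plain probabilities as the first argument of kl_div gives a different number than scipy:

```python
import torch
import torch.nn.functional as F
from scipy.stats import entropy

# Hypothetical example distributions (not the original poster's data)
p = [0.36, 0.48, 0.16]
q = [1 / 3, 1 / 3, 1 / 3]

# scipy: KL(p || q) = sum(p * log(p / q)), natural log
kl_scipy = entropy(p, q)

# Naive call: q passed as plain probabilities in `input`, so kl_div computes
# sum(p * (log(p) - q)) instead of sum(p * (log(p) - log(q)))
kl_torch_naive = F.kl_div(
    torch.tensor(q, dtype=torch.float64),
    torch.tensor(p, dtype=torch.float64),
    reduction="sum",
).item()

print(kl_scipy)        # about 0.0853
print(kl_torch_naive)  # a different (here negative) value
```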
So if we look at the documentation for kl_div, we are instructed to see KLDivLoss for details.
There it says:
the input given is expected to contain log-probabilities and is not restricted to a 2D Tensor. The targets are interpreted as probabilities by default, but could be considered as log-probabilities with log_target set to True.
In other words, the first argument should be log probs.
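A minimal sketch of the corrected call, reusing the same hypothetical distributions as an assumption. Note the argument order: kl_div(input, target) computes KL(target || input), with `input` given as log-probabilities, so it corresponds to entropy(p, q) when called as kl_div(q.log(), p). The log_target=True variant takes both arguments as log-probabilities:

```python
import torch
import torch.nn.functional as F
from scipy.stats import entropy

# Hypothetical distributions (same as the Wikipedia example)
p = torch.tensor([0.36, 0.48, 0.16], dtype=torch.float64)     # target, "true" distribution
q = torch.tensor([1 / 3, 1 / 3, 1 / 3], dtype=torch.float64)  # approximating distribution

# input must be log-probabilities; target stays as probabilities by default
kl_torch = F.kl_div(q.log(), p, reduction="sum").item()

# Alternatively, pass both as log-probabilities and set log_target=True
kl_torch_log = F.kl_div(q.log(), p.log(), reduction="sum", log_target=True).item()

kl_scipy = entropy(p.numpy(), q.numpy())
print(kl_torch, kl_torch_log, kl_scipy)  # all about 0.0853
```

The reduction matters too: the default "mean" averages over all elements (here it would divide by 3), so "sum" is used to match scipy for a single 1-D distribution; for batches of distributions stored in rows, "batchmean" is the option that averages the per-row KL values.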
Personally, I wouldn't use FloatTensor: it has been superseded as the preferred way to create tensors for longer than it ever was the preferred way.
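For illustration, a small sketch of creating the same tensors with torch.tensor instead of the legacy constructor. It assumes the default dtype has not been changed with torch.set_default_dtype:

```python
import torch

# Legacy style (discouraged): torch.FloatTensor([0.36, 0.48, 0.16])

# torch.tensor infers float32 from Python floats (under the default dtype setting)
p32 = torch.tensor([0.36, 0.48, 0.16])

# An explicit dtype is clearer when precision matters
p64 = torch.tensor([0.36, 0.48, 0.16], dtype=torch.float64)

print(p32.dtype, p64.dtype)  # torch.float32 torch.float64
```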
Thanks, that makes sense. I must have missed the log-probability part, although it seems odd that there is no option to pass probabilities directly as the input.
And thanks for pointing out the FloatTensor issue. I guess torch.tensor() is the recommended way now.
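If you want to pass probabilities directly, one option outside of kl_div (a different API, from torch.distributions rather than the loss functions) is a sketch like the following, again using hypothetical example distributions. Categorical accepts probabilities via `probs=`, and kl_divergence(P, Q) computes KL(P || Q):

```python
import torch
from torch.distributions import Categorical, kl_divergence

# Hypothetical distributions given as plain probabilities
p = torch.tensor([0.36, 0.48, 0.16])
q = torch.tensor([1 / 3, 1 / 3, 1 / 3])

# KL(P || Q) for two categorical distributions; the log is taken internally
kl = kl_divergence(Categorical(probs=p), Categorical(probs=q)).item()
print(kl)  # about 0.0853
```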
So the convention of taking log-probs has grown historically, but as a rule of thumb, probabilities very close to 1 are difficult to work with: try evaluating 1 - 1e-20, which rounds to exactly 1.0, and float runs out of precision much sooner than double. If you know that your probabilities are close to 1 but not close to 0, you can avoid this by consistently working with 1 - p instead, but in general, using log-probs as the default representation of probabilities seems like a good idea.
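A small sketch of the precision issues described above, using plain Python floats (float64) and float32 tensors. The specific values (1e-8, 1e-50) are illustrative choices picked to trigger rounding and underflow, and math.log1p is one standard way to exploit a stored complement, not something the thread prescribes:

```python
import math
import torch

# float64 (Python float): 1 - 1e-20 rounds to exactly 1.0
print(1.0 - 1e-20 == 1.0)  # True

# float32 runs out of precision much sooner: 1 - 1e-8 already rounds to 1.0
one32 = torch.tensor(1.0, dtype=torch.float32)
print((one32 - 1e-8).item() == 1.0)  # True

# Storing the complement q = 1 - p directly keeps the information
q = 1e-20
print(math.log(1.0 - q))  # 0.0, the information is lost
print(math.log1p(-q))     # about -1e-20, log(1 - q) computed accurately

# Near 0, float32 probabilities underflow while log-probs stay finite
p_small = torch.tensor(1e-50, dtype=torch.float32)
print(p_small.item())             # 0.0 (underflow)
print(torch.log(p_small).item())  # -inf
print(math.log(1e-50))            # about -115.13
```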
This is of distinctly practical relevance and crops up quite frequently.
So to me it makes a lot of sense to use log-probs whenever you actually compute something with them, but I agree that it might be better to mention this explicitly in the kl_div documentation as well (personally, I try to name such parameters log_probs or similar, but I guess that's not trivial to change now).