When I use torch.nn.functional.kl_div(), I notice that while the mean of the result is positive, some values in the unreduced result are negative. Is this the correct behavior, or did I make a mistake in my code?
Thank you very much!
I just tried a toy example and got the same result, so I guess it’s the correct behavior of kl_div…
I’m still a little confused about how the mean of the unreduced result is guaranteed to be non-negative, though I know the KL divergence should be non-negative. Any explanation is appreciated!
You are correct that the KL divergence is non-negative, and you are
also correct that the individual terms returned by PyTorch’s “unreduced”
kl_div() can be negative.
The reason is that, by definition, the KL divergence is taken between
two probability distributions, not between two arbitrary sets of numbers.
If you apply the KL-divergence formula to arbitrary numbers, the result
is not guaranteed to be non-negative.
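To make this concrete, here is a minimal sketch in plain Python. The helper name kl_formula and the numbers are made up for illustration; it just evaluates sum(p * log(p / q)) without checking that the inputs are normalized.

```python
import math

def kl_formula(p, q):
    # Hypothetical helper: applies the KL-divergence formula term by term
    # and sums, with no check that p and q are probability distributions.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# Arbitrary positive numbers that do NOT sum to one: the result is negative.
not_normalized = kl_formula([0.1, 0.1, 0.1], [0.5, 0.5, 0.5])

# Proper probability distributions (each sums to one): the result is >= 0.
normalized = kl_formula([0.2, 0.3, 0.5], [0.5, 0.3, 0.2])

print(not_normalized, normalized)
```

In the first call every term is 0.1 * log(0.2), which is negative, so the sum is negative too. Nothing is wrong with the formula; the inputs simply are not distributions.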
In your example, a and b each contain three elements. Let’s therefore
say that you are working with a three-class problem.
Here a is a log-probability distribution over the three classes, and
satisfies a.exp().sum() = 1.0, that is, the three probabilities sum
to one, as is required of a probability distribution. Similarly, b
(after the normalization step in your example) is a probability
distribution over the three classes and sums to one.
The result of the unreduced
kl_div() is simply the three terms in
the KL-divergence formula before they are summed over the three
classes. Not being summed yet, the individual terms “can’t know”
that they come from proper probability distributions that sum to one,
and need not be non-negative. Only when summed together to form
the full KL divergence do they produce a result that is guaranteed to
be non-negative.
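Here is a small sketch of that behavior. The logits are made up for illustration, and I am assuming the usual convention that the first argument of kl_div() holds log-probabilities and the second holds probabilities.

```python
import torch
import torch.nn.functional as F

# Hypothetical three-class example; these logits are arbitrary.
a = F.log_softmax(torch.tensor([2.0, 0.5, -1.0]), dim=0)  # log-probabilities
b = F.softmax(torch.tensor([0.1, 1.5, 0.3]), dim=0)       # probabilities

# Unreduced result: one term b * (log(b) - a) per class.
terms = F.kl_div(a, b, reduction="none")

# Summing the terms over the classes gives the full KL divergence.
total = terms.sum()

print(terms)  # the first term is negative for these inputs
print(total)  # the sum is non-negative
```

Note that terms.mean() is just total divided by the number of elements, so it has the same sign as the sum. That is why the mean comes out non-negative even though individual entries do not.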
It is a theorem – and, in the discrete case, the result of a not particularly
difficult calculation – that the KL divergence is non-negative, but the
theorem (and calculation) rely on the KL-divergence formula being
applied to legitimate probability distributions.
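For completeness, here is a sketch of the discrete calculation. The notation is mine, not PyTorch's: p_i stands for the target probabilities (b above) and q_i for the probabilities a.exp(). The only ingredient is the inequality log x <= x - 1 for x > 0, and terms with p_i = 0 are taken to contribute zero.

```latex
-D_{\mathrm{KL}}(p \,\|\, q)
  = \sum_i p_i \log \frac{q_i}{p_i}
  \le \sum_i p_i \left( \frac{q_i}{p_i} - 1 \right)
  = \sum_i q_i - \sum_i p_i
  = 1 - 1
  = 0
```

The last step is exactly where normalization enters: if the p_i and q_i did not each sum to one, the bound would be sum(q) - sum(p) instead of zero, and the non-negativity guarantee would be lost.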
Thank you! Your answer perfectly resolves my confusion.