KL divergence loss

I'm trying to implement a KL divergence loss, but I always get nan.

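import torch
import torch.nn.functional as F

p = torch.randn((100, 100))
q = torch.randn((100, 100))
# randn returns negative values, so p.log() (and the "probabilities" in q) are invalid
kl_loss = torch.nn.KLDivLoss(reduction="sum")(p.log(), q)
# output = nan
p_soft = F.softmax(p, dim=1)
q_soft = F.softmax(q, dim=1)
kl_loss = torch.nn.KLDivLoss(reduction="sum")(p_soft.log(), q_soft)
# output = 96.7017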
Do we always have to pass the distributions (p, q) through a softmax function?

According to the docs:

As with NLLLoss, the input given is expected to contain log-probabilities and is not restricted to a 2D Tensor. The targets are given as probabilities (i.e. without taking the logarithm).
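A minimal sketch of what that convention means, assuming the default `log_target=False`: `F.kl_div(input, target)` computes the pointwise term `target * (log(target) - input)`, so `input` must already be log-probabilities and `target` plain probabilities. The tensor shapes and the hand-written formula are only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative batches of scores over 5 classes.
logits_p = torch.randn(4, 5)
logits_q = torch.randn(4, 5)

log_p = F.log_softmax(logits_p, dim=1)  # input: log-probabilities
q = F.softmax(logits_q, dim=1)          # target: probabilities (rows sum to 1)

builtin = F.kl_div(log_p, q, reduction="sum")

# Same quantity written out by hand: sum of q * (log q - log p), i.e. KL(q || p).
manual = (q * (q.log() - log_p)).sum()

print(builtin, manual)
```

Note the argument order: the target `q` plays the role of the "true" distribution in KL(q || p).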

Your code snippet looks alright. I would recommend using log_softmax instead of softmax().log(), as the former is numerically more stable.
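A small illustration of the stability point, using a made-up pair of logits that are far apart: in float32, `softmax` underflows the small probability to exactly 0, so calling `.log()` afterwards yields `-inf`, while `log_softmax` computes the log-probabilities directly and stays finite.

```python
import torch
import torch.nn.functional as F

# Illustrative logits that are far apart.
x = torch.tensor([0.0, 200.0])

# exp(-200) underflows to 0 in float32, so the log of that entry is -inf.
naive = F.softmax(x, dim=0).log()

# log_softmax works in log space and keeps the entry finite (about -200).
stable = F.log_softmax(x, dim=0)

print(naive)
print(stable)
```

A `-inf` log-probability in the input of `kl_div` can then turn into `inf` or `nan` in the loss and its gradients.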


As I understand it, KL divergence is supposed to return a value measuring the distance between two probability distributions in feature space. However, this is what I got using log_softmax…

>>> import torch
>>> import torch.nn.functional as F
>>> k = torch.rand(256)
>>> k1 = k.clone()
>>> F.kl_div(F.log_softmax(k, dim=0), k1, reduction="none").mean()

On the other hand, when I use a plain log, the answer is zero, which is what I expected. What should I use when comparing two layers in PyTorch with KL divergence?

The target should be given as probabilities:

import torch
import torch.nn.functional as F

k = torch.rand(256)
k1 = k.clone()
F.kl_div(F.log_softmax(k, 0), F.softmax(k1, 0), reduction="none").mean()
# > tensor(6.2333e-10), i.e. zero up to floating-point error
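For the question about comparing two layers, a minimal sketch, assuming both layers output unnormalized scores (logits) of the same shape `(batch_size, N)`: normalize both along the class dimension, then pass the input as log-probabilities and the target as probabilities. The `nn.Linear` layers and sizes are hypothetical. In PyTorch 1.6 and later, `F.kl_div` also accepts `log_target=True`, so the target can be passed as log-probabilities too.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, in_features, num_classes = 8, 16, 10
x = torch.randn(batch_size, in_features)

# Two hypothetical layers whose outputs we want to compare.
layer_a = nn.Linear(in_features, num_classes)
layer_b = nn.Linear(in_features, num_classes)
logits_a = layer_a(x)
logits_b = layer_b(x)

# Input: log-probabilities; target: probabilities (both over the class dim).
kl = F.kl_div(F.log_softmax(logits_a, dim=1), F.softmax(logits_b, dim=1),
              reduction="batchmean")

# Equivalent form with the target also given in log space.
kl_log_target = F.kl_div(F.log_softmax(logits_a, dim=1),
                         F.log_softmax(logits_b, dim=1),
                         reduction="batchmean", log_target=True)

# Comparing a layer's output with itself gives (numerically) zero.
kl_self = F.kl_div(F.log_softmax(logits_a, dim=1), F.softmax(logits_a, dim=1),
                   reduction="batchmean")
```

The zero obtained earlier with a plain `log` (`k1` equal to `k`) does not confirm correctness: `kl_div(k.log(), k)` is zero for any positive `k`, normalized or not, because each term is `k * (log k - log k)`.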

Thank you! I had missed that. 😅

Thank you, @ptrblck. There is something I don't understand.
Let's assume that shape(target) = shape(input) = (batch_size, N).

Should the log_softmax / softmax be applied in dimension 0?
F.kl_div(F.log_softmax(logits, dim=0), F.softmax(target, dim=0), reduction="none").mean()

Or in dimension 1?
F.kl_div(F.log_softmax(logits, dim=1), F.softmax(target, dim=1), reduction="none").mean()

Personally, I think it should be dimension 1 (N here plays a role similar to the number of classes in a classification task).

You would apply the log_softmax in the class dimension, which is usually dim=1.
Note that my example is not really representative; I was just reusing the posted code.
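A sketch of the batched case from the question, assuming `logits` and `target` (hypothetical names, shape `(batch_size, N)`) hold unnormalized scores: with `dim=1`, each row becomes a distribution over the N classes. It also compares two reductions: `reduction="none"` followed by `.mean()` divides the total by `batch_size * N`, whereas `reduction="batchmean"` divides only by the batch size. The PyTorch docs note that `batchmean` is the one matching the mathematical definition of KL divergence.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, N = 4, 10
logits = torch.randn(batch_size, N)  # hypothetical model output
target = torch.randn(batch_size, N)  # hypothetical target scores

log_p = F.log_softmax(logits, dim=1)  # normalize each row over the N classes
q = F.softmax(target, dim=1)

# Per-sample KL(q_i || p_i): sum over the classes of each row.
per_sample = (q * (q.log() - log_p)).sum(dim=1)

batchmean = F.kl_div(log_p, q, reduction="batchmean")           # total / batch_size
elementwise_mean = F.kl_div(log_p, q, reduction="none").mean()  # total / (batch_size * N)

print(per_sample.mean(), batchmean, elementwise_mean)
```

So `.mean()` over `reduction="none"` is the per-sample KL scaled down by a factor of N; the ranking of losses is unchanged, but the value (and gradient magnitude) is not the textbook KL.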


The following issue is relevant to anyone using F.kl_div or its module counterpart, nn.KLDivLoss, as its current behaviour is wrong.