Trying to implement a KL divergence loss, but it always returns nan.
import torch

p = torch.randn((100, 100))
q = torch.randn((100, 100))
kl_loss = torch.nn.KLDivLoss(size_average=False)(p.log(), q)
# output: nan
import torch.nn.functional as F

p_soft = F.softmax(p, dim=1)
q_soft = F.softmax(q, dim=1)
kl_loss = torch.nn.KLDivLoss(size_average=False)(p_soft.log(), q_soft)
# output: 96.7017
Do we have to pass the distributions (p, q) through softmax function always?
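For context on where the nan comes from: the first snippet calls .log() directly on the output of torch.randn, which contains negative values, and the log of a negative number is nan. A minimal sketch:

```python
import torch

torch.manual_seed(0)
p = torch.randn(4)                 # standard normal samples, some negative
print(p.log())                     # nan wherever p is negative
print(torch.isnan(p.log()).any())  # tensor(True)
```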
According to the docs: "the input given is expected to contain log-probabilities and is not restricted to a 2D Tensor. The targets are given as probabilities (i.e. without taking the logarithm)."
Your code snippet looks alright. I would recommend using log_softmax instead of softmax().log(), as the former approach is numerically more stable.
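A small sketch of that stability difference, with logit values chosen to force underflow: softmax().log() produces -inf once softmax underflows to zero, while log_softmax stays finite.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([100.0, 0.0, -100.0])
print(F.softmax(x, dim=0).log())  # -inf where softmax underflowed to 0.0
print(F.log_softmax(x, dim=0))    # finite values (approx. [0., -100., -200.])
```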
As I understand it, KL divergence is supposed to return a numerical value representing the distance between two probability distributions in feature space. However, this is what I got using log_softmax…
On the other hand, when I use a plain log, the answer is zero, which is expected. Can you let me know what I should use when comparing two layers in PyTorch with KL divergence?
The target should be given as probabilities:
k = torch.rand(256)
k1 = k.clone()
F.kl_div(F.log_softmax(k, 0), F.softmax(k1, 0), reduction="none").mean()
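As a sanity check on the snippet above: since k1 is a clone of k, the two softmax distributions are identical, so the KL divergence should be (numerically) zero.

```python
import torch
import torch.nn.functional as F

k = torch.rand(256)
k1 = k.clone()
out = F.kl_div(F.log_softmax(k, 0), F.softmax(k1, 0), reduction="none").mean()
print(out)  # ~0, up to floating-point error
```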
Thank you! That's what I was missing.
There is something I don’t understand.
Let’s assume that:
shape(target) = shape(input) = (batch_size, N)
Should the log_softmax / softmax be applied along dimension 0?
F.kl_div(F.log_softmax(logits, dim=0), F.softmax(target, dim=0), reduction="none").mean()
Or along dimension 1?
F.kl_div(F.log_softmax(logits, dim=1), F.softmax(target, dim=1), reduction="none").mean()
Personally, I think dimension 1 (N here is a bit like the number of classes in a classification problem).
You would apply the log_softmax in the class dimension, so usually in dim 1.
Note that my example is not a really representative one; I was just reusing the posted code.
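A sketch of the batched (batch_size, N) case with hypothetical sizes, normalizing along the class dimension. Note the docs recommend reduction="batchmean" when you want the value to match the mathematical definition of KL divergence:

```python
import torch
import torch.nn.functional as F

batch_size, num_classes = 8, 10          # hypothetical sizes
logits = torch.randn(batch_size, num_classes)
target = torch.randn(batch_size, num_classes)

# One distribution per row: normalize along the class dimension (dim=1)
loss = F.kl_div(
    F.log_softmax(logits, dim=1),
    F.softmax(target, dim=1),
    reduction="batchmean",               # sum over classes, mean over batch
)
print(loss)  # non-negative scalar
```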
The following issue is relevant for people using kl_div or its nn module, as its current behaviour is wrong:
08:24AM - 03 May 21 UTC
module: correctness (silent)
## 🐛 Bug
### Executive summary:
The inputs of `KLDivLoss` and `F.kl_div` are i