KL loss and log_softmax

Hi, as mentioned in the docs, the student logits need to pass through a log_softmax. My question is: why? log_softmax doesn’t give us a distribution, so is it because of the way PyTorch implements the KL loss?

Thanks!

log_softmax gives you logarithms of a distribution…

Expanding D_KL(P_target || P_prediction) = Σ P_target * (log P_target - log P_prediction), we see that P_prediction only appears as a logarithm. Computing the softmax (an exponential) and then taking the logarithm again inside the loss would be numerically less stable than computing log_softmax directly, so it makes sense to avoid that round trip. That’s why KLDivLoss takes P_target and log P_prediction as inputs.
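A minimal sketch of what this looks like in practice, assuming a distillation-style setup with made-up teacher and student logits (the tensor shapes and names are just for illustration): the student logits go through log_softmax to give log P_prediction, the teacher logits through softmax to give P_target, and both are handed to kl_div.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for a batch of 4 examples over 10 classes.
teacher_logits = torch.randn(4, 10)
student_logits = torch.randn(4, 10, requires_grad=True)

# kl_div expects the prediction as log-probabilities and the target as
# probabilities, matching D_KL(P_target || P_prediction) above.
loss = F.kl_div(
    F.log_softmax(student_logits, dim=-1),  # log P_prediction
    F.softmax(teacher_logits, dim=-1),      # P_target
    reduction="batchmean",
)
loss.backward()
```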

Best regards

Thomas
