Due to a constraint in our project, I need to rewrite the KL divergence computation using basic PyTorch operations. Here is my implementation: torch.mean(q * torch.log(q / p))
I compared the output values of this implementation with those from the following: nn.KLDivLoss()(p.log(), q)
They produce the same values on my limited test set.
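Here is a minimal sketch of the comparison (the shapes and the softmax-generated distributions are illustrative, not my actual data):

import torch
import torch.nn as nn

# Illustrative comparison: p and q are softmax-normalized, so every
# entry is strictly positive and each row sums to 1.
torch.manual_seed(0)
p = torch.softmax(torch.randn(8, 5), dim=-1)  # stand-in prediction
q = torch.softmax(torch.randn(8, 5), dim=-1)  # stand-in target

manual  = torch.mean(q * torch.log(q / p))
builtin = nn.KLDivLoss()(p.log(), q)  # default reduction='mean' averages over all elements

print(torch.allclose(manual, builtin))  # True for well-behaved p and q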
When I kept using
nn.KLDivLoss()(p.log(), q)
everything was fine. However, when I replaced it with
torch.mean(q * torch.log(q / p))
some part of my model started generating NaN after several training epochs. Could someone provide guidance on how this could happen and how to fix it?
Depending on your use case, computing log(q) - log(p) can be more numerically stable than computing log(q / p). (They are the same mathematically, but can differ numerically.) In particular, if p contains very small values, q / p can overflow to inf, so torch.log(q / p) returns inf, which then poisons the loss and its gradients, whereas q.log() - p.log() stays finite.
Try using torch.mean(q * (q.log() - p.log())) as your KL divergence implementation.
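Here is a minimal sketch of that form, together with a contrived example of where the single-log version overflows (the subnormal value in p is chosen purely for demonstration):

import torch

# A minimal sketch of the two-log form. Assumes p and q are probability
# tensors with strictly positive entries along the distribution dimension.
def kl_div(p, q):
    return torch.mean(q * (q.log() - p.log()))

# Contrived illustration of the failure mode: with a subnormal value in p,
# q / p overflows float32 to inf, so log(q / p) is inf, while the two-log
# form stays finite.
p = torch.tensor([1e-45, 1.0])  # 1e-45 is a float32 subnormal, chosen for demonstration
q = torch.tensor([0.5, 0.5])

print(torch.log(q / p))   # tensor([inf, -0.6931])
print(q.log() - p.log())  # roughly tensor([102.5867, -0.6931]) -- finite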
As an aside, p would typically be a prediction, that is, the output of your model.
If so (for example, if the last layer of your model is a Linear), the output of
your model (let's call it y) might naturally already be in log-space. That is,
your model might naturally be outputting y = p.log(). If so, just use y directly
(without ever explicitly converting it to p = y.exp()). In this case, your KL
divergence would become torch.mean(q * (q.log() - y)).
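Here is a toy sketch of that pattern; the Linear sizes, the input x, and the target q below are assumptions for illustration:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sketch: the model ends in a Linear layer and emits log-probabilities
# directly via log_softmax, so p itself is never materialized.
model = nn.Linear(10, 5)

x = torch.randn(4, 10)
y = F.log_softmax(model(x), dim=-1)           # y = p.log(), computed stably
q = torch.softmax(torch.randn(4, 5), dim=-1)  # stand-in target distribution

loss = torch.mean(q * (q.log() - y))          # KL divergence, no exp/log round-trip
loss.backward()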