Hi Jim!
Depending on your use case, computing log (q) - log (p) can be more stable numerically than computing log (q / p). (They are the same mathematically, but can differ numerically.) Try using torch.mean (q * (q.log() - p.log())) as your KL divergence implementation.
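
Here is a minimal sketch of that expression (the shapes and the tensors q and p are made up for illustration; each row is a probability distribution):

```python
import torch

# Illustrative only: q and p are (batch, n_classes) distributions
# whose rows sum to one.
torch.manual_seed(0)
q = torch.softmax(torch.randn(4, 5), dim=1)
p = torch.softmax(torch.randn(4, 5), dim=1)

# Subtract the logs rather than taking the log of the quotient,
# so the ratio q / p is never formed explicitly.
kl = torch.mean(q * (q.log() - p.log()))
print(kl)
```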
As an aside, p would typically be a prediction, that is, the output of your model. If so (for example, if the last layer of your model is a Linear), the output of your model, let's call it y, might naturally already be in log-space. That is, your model might naturally be outputting y = p.log(). If so, just use y directly (without ever explicitly converting it to p = y.exp()). In this case, your KL divergence would become torch.mean (q * (q.log() - y)).
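
A minimal sketch of that case (the layer sizes, the input x, and the use of log_softmax to put the output in log-space are all assumptions for illustration, not your actual model):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical model whose final Linear feeds a log_softmax, so its
# output y is already log (p).  Sizes are made up.
torch.manual_seed(0)
model = nn.Linear(8, 5)

x = torch.randn(4, 8)                        # illustrative input batch
y = F.log_softmax(model(x), dim=1)           # y = p.log(), never converted back with .exp()

q = torch.softmax(torch.randn(4, 5), dim=1)  # illustrative target distribution

# KL divergence using y directly in place of p.log().
kl = torch.mean(q * (q.log() - y))
print(kl)
```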
Best.
K. Frank