The implementation of torch.nn.functional.kl_div()

I am using torch.nn.functional.kl_div() to calculate the KL divergence between the outputs of two networks. However, it seems the output of F.kl_div() is not consistent with the definition.

For example, let's assume the normalized pred = torch.Tensor([[0.2, 0.8]]) and target = torch.Tensor([[0.1, 0.9]]).

Then the output of F.kl_div() would be:
F.kl_div(pred, target, reduction='sum', log_target=False) -> -1.0651
or
F.kl_div(pred, target, reduction='sum', log_target=True) -> 0.1354

However, if I calculate the KL divergence according to the definition:
(pred * torch.log(pred / target)).sum() -> 0.0444

Does anyone know the reason for the difference (the torch version is 1.8)?

Any thoughts?

I get the same outputs using the definition from the docs:

pred = torch.tensor([[0.2, 0.8]])
target = torch.tensor([[0.1, 0.9]])
F.kl_div(pred, target, reduction='sum', log_target=False)
> tensor(-1.0651)
torch.sum(target * (torch.log(target) - pred))
> tensor(-1.0651)

F.kl_div(pred, target, reduction='sum', log_target=True)
> tensor(0.1354)
torch.sum(torch.exp(target) * (target - pred))
> tensor(0.1354)
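
In other words, the first argument is expected to already contain log-probabilities. As a quick sketch with the same tensors, assuming what you actually want is KL(target || pred), logging the first argument should give roughly:

import torch
import torch.nn.functional as F

pred = torch.tensor([[0.2, 0.8]])
target = torch.tensor([[0.1, 0.9]])
# pred.log() turns the first argument into log-probabilities, as expected by F.kl_div
F.kl_div(pred.log(), target, reduction='sum')
> tensor(0.0367)
# matches the textbook definition of KL(target || pred)
torch.sum(target * torch.log(target / pred))
> tensor(0.0367)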

Hi, ptrblck,

Thanks for your reply. I think I understand:

  1. pred is expected to already be log-probabilities when it enters F.kl_div(), which is why no torch.log() is applied to pred inside the function.
  2. F.kl_div(pred, target) computes KL(target || pred) rather than KL(pred || target), is that correct? (See the sketch below.)
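
If that is right, then a small sketch along these lines (reusing the same tensors) should reproduce my manual value of 0.0444 once the two arguments are swapped and the first one is logged:

import torch
import torch.nn.functional as F

pred = torch.tensor([[0.2, 0.8]])
target = torch.tensor([[0.1, 0.9]])
# my original hand computation: KL(pred || target)
(pred * torch.log(pred / target)).sum()
> tensor(0.0444)
# same value from F.kl_div: pred now plays the role of the "target" argument,
# and the first argument must be log-probabilities
F.kl_div(target.log(), pred, reduction='sum')
> tensor(0.0444)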