KL divergence loss PyTorch implementation

According to the theory, KL divergence is the difference between the cross entropy (of the inputs and the targets) and the entropy (of the targets).
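In symbols, with target distribution q and predicted distribution p:

KL(q || p) = H(q, p) − H(q)

where H(q, p) = −Σ_i q_i log p_i is the cross entropy and H(q) = −Σ_i q_i log q_i is the entropy of the targets.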

I compared the KL div loss implementation in PyTorch against a custom implementation based on the above theory, but the results are not the same and I am not sure why there is a difference.

import torch
import torch.nn as nn

m = nn.Softmax(dim=1)
i_tensor_before_softmax = torch.Tensor([0.1, 0.2, 0.4, 0.3])
i_tensor = m(i_tensor_before_softmax.view(-1,4))
o_tensor_before_softmax = torch.Tensor([0.7, 0.1, 0.1, 0.1])
o_tensor = m(o_tensor_before_softmax.view(-1,4))

import torch.nn.functional as F
kl_loss = nn.KLDivLoss(reduction="batchmean")
loss = torch.nn.CrossEntropyLoss()
kl_output = kl_loss(input=F.log_softmax(i_tensor_before_softmax, dim=-1), target=o_tensor)
cross_ent = loss(input=i_tensor_before_softmax.view(-1,4), target=o_tensor)
ent = loss(input=o_tensor_before_softmax.view(-1,4), target=o_tensor)
kl_output_using_ce = cross_ent - ent
print(kl_output, kl_output_using_ce)

output: tensor(0.0179) tensor(0.0716)

Help me out in identifying the issue :smile:
Thanks,
Krishnan

Hi Krishnan!

The short story is that your results differ by a factor of four: the way you are calling it,
KLDivLoss ends up dividing the summed (pointwise) KL divergence by four, while
CrossEntropyLoss sums over the class dimension without dividing.

Both computations are correct on their own terms; they just normalize differently.

For what are presumably historical reasons (CrossEntropyLoss did not previously
support probabilistic targets), KLDivLoss computes what it calls the “pointwise”
KL divergence, while PyTorch has no corresponding notion of a “pointwise” cross
entropy.
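Concretely, the “pointwise” term that KLDivLoss reduces over is target * (target.log() - input), with the input given as log-probabilities. As a quick sketch using the tensors from your example (the variable names here are just illustrative):

# elementwise KL term; matches KLDivLoss with reduction = 'none' (see the transcript below)
log_p = F.log_softmax(i_tensor_before_softmax, dim=-1)
pointwise_kl = o_tensor * (o_tensor.log() - log_p)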

Let’s call the first dimension of your tensors the “batch” dimension and
the second the “class” dimension. So in your example your tensors have
shape [nBatch = 1, nClass = 4].

With reduction = 'batchmean', KLDivLoss sums the pointwise-KL-divergence tensor
and then divides by input.size (0). Because the input you pass it,
F.log_softmax (i_tensor_before_softmax, dim=-1), is an un-batched, one-dimensional
tensor of length four, that division is by four, i.e., by nClass. (The default
reduction of 'mean', which takes the mean over all elements of the pointwise
tensor, would also divide by four here, and the mean over the nBatch = 1
dimension is trivial.)
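(Numerically, the summed pointwise KL divergence is 0.0716, and 0.0716 / 4 ≈ 0.0179, which is exactly the kl_output you printed.)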

CrossEntropyLoss, by definition, sums over the nClass dimension, not
dividing by four (and then, with its default reduction = 'mean', takes the
trivial mean over the nBatch dimension).
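Equivalently, with probabilistic targets CrossEntropyLoss computes -(target * log_softmax (input)).sum() over the class dimension and then averages over the batch. A rough sketch with your tensors (again, variable names are just illustrative):

# soft-target cross entropy by hand: sum over nClass, then mean over nBatch
log_probs = F.log_softmax(i_tensor_before_softmax.view(-1, 4), dim=1)
manual_cross_ent = -(o_tensor * log_probs).sum(dim=1).mean()   # should match your cross_ent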

Therefore your two results differ by a factor of four.

Consider:

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> m = torch.nn.Softmax (dim=1)
>>> i_tensor_before_softmax = torch.Tensor ([0.1, 0.2, 0.4, 0.3])
>>> i_tensor = m (i_tensor_before_softmax.view (-1,4))
>>> o_tensor_before_softmax = torch.Tensor ([0.7, 0.1, 0.1, 0.1])
>>> o_tensor = m (o_tensor_before_softmax.view (-1,4))
>>>
>>> import torch.nn.functional as F
>>> kl_loss = torch.nn.KLDivLoss (reduction = 'batchmean')
>>> loss = torch.nn.CrossEntropyLoss()
>>> kl_output = kl_loss (input = F.log_softmax (i_tensor_before_softmax, dim=-1), target = o_tensor)
>>> cross_ent = loss (input = i_tensor_before_softmax.view (-1,4), target = o_tensor)
>>> ent = loss (input=o_tensor_before_softmax.view (-1,4), target = o_tensor)
>>> kl_output_using_ce = cross_ent - ent
>>> print (kl_output, kl_output_using_ce)
tensor(0.0179) tensor(0.0716)
>>>
>>> kl_loss = torch.nn.KLDivLoss (reduction = 'none')
>>> loss = torch.nn.CrossEntropyLoss (reduction = 'none')
>>> kl_output_unr = kl_loss (input = F.log_softmax (i_tensor_before_softmax, dim=-1), target = o_tensor)
>>> cross_ent_unr = loss (input = i_tensor_before_softmax.view (-1,4), target = o_tensor)
>>> ent_unr = loss (input=o_tensor_before_softmax.view (-1,4), target = o_tensor)
>>> kl_output_using_ce_unr = cross_ent_unr - ent_unr
>>> print (kl_output_unr)
tensor([[ 0.2151, -0.0271, -0.0686, -0.0478]])
>>> print (kl_output_using_ce_unr)
tensor([0.0716])
>>> print (kl_output_unr.mean(), kl_output_unr.sum(), kl_output_using_ce_unr)
tensor(0.0179) tensor(0.0716) tensor([0.0716])
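
If you want the two numbers to agree, one option (just a sketch, not the only way) is to hand KLDivLoss a two-dimensional [nBatch, nClass] input, so that 'batchmean' divides by nBatch = 1 instead of by four:

>>> kl_loss_bm = torch.nn.KLDivLoss (reduction = 'batchmean')
>>> log_p_2d = F.log_softmax (i_tensor_before_softmax.view (-1, 4), dim = -1)   # shape [1, 4]
>>> kl_matching = kl_loss_bm (input = log_p_2d, target = o_tensor)   # should now equal cross_ent - ent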

Best.

K. Frank
