Hi Krishnan!

The short story is that your results differ by a factor of four because `KLDivLoss` takes the mean over the class dimension, while `CrossEntropyLoss` sums over it. This is correct.

For what are presumably historical reasons (`CrossEntropyLoss` did not always support probabilistic targets), `KLDivLoss` computes what it calls the "pointwise" KL-divergence, while PyTorch does not use a notion of "pointwise" cross entropy.
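
To make "pointwise" concrete: with `input` already a log-probability and a probabilistic `target`, `KLDivLoss` computes one term per class entry, `target * (target.log() - input)`, and only then applies its `reduction`. A small sketch with made-up probabilities:

```
import torch

# made-up probability vectors with shape [nBatch = 1, nClass = 4]
log_probs = torch.tensor ([[0.25, 0.25, 0.25, 0.25]]).log()   # "input": log-probabilities
target = torch.tensor ([[0.70, 0.10, 0.10, 0.10]])            # "target": probabilities

# the pointwise KL-divergence: one entry per class, not yet reduced
pointwise = target * (target.log() - log_probs)

# KLDivLoss with reduction = 'none' returns this same pointwise tensor
print (pointwise)
print (torch.nn.KLDivLoss (reduction = 'none') (log_probs, target))
```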

Let's call the first dimension of your tensors the "batch" dimension and the second the "class" dimension. So in your example your tensors have shape `[nBatch = 1, nClass = 4]`.

`KLDivLoss`'s default `reduction` of `'mean'` takes the mean over all dimensions of the pointwise-KL-divergence tensor, so it takes the mean over the `nClass` dimension, dividing by four (and then takes the trivial mean over the `nBatch = 1` dimension).

`CrossEntropyLoss`, by definition, sums over the `nClass` dimension, not dividing by four (and then, with its default `reduction = 'mean'`, takes the trivial mean over the `nBatch` dimension).

Therefore your two results differ by a factor of four.
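
(Concretely, in the example below the four pointwise KL values sum to about 0.0716, and 0.0716 / 4 ≈ 0.0179, which are precisely the two numbers being compared.)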

Consider:

```
>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> m = torch.nn.Softmax (dim=1)
>>> i_tensor_before_softmax = torch.Tensor ([0.1, 0.2, 0.4, 0.3])
>>> i_tensor = m (i_tensor_before_softmax.view (-1,4))
>>> o_tensor_before_softmax = torch.Tensor ([0.7, 0.1, 0.1, 0.1])
>>> o_tensor = m (o_tensor_before_softmax.view (-1,4))
>>>
>>> import torch.nn.functional as F
>>> kl_loss = torch.nn.KLDivLoss (reduction = 'mean')   # the default reduction, which averages over all elements
>>> loss = torch.nn.CrossEntropyLoss()
>>> kl_output = kl_loss (input = F.log_softmax (i_tensor_before_softmax, dim=-1), target = o_tensor)
>>> cross_ent = loss (input = i_tensor_before_softmax.view (-1,4), target = o_tensor)
>>> ent = loss (input=o_tensor_before_softmax.view (-1,4), target = o_tensor)
>>> kl_output_using_ce = cross_ent - ent
>>> print (kl_output, kl_output_using_ce)
tensor(0.0179) tensor(0.0716)
>>>
>>> kl_loss = torch.nn.KLDivLoss (reduction = 'none')
>>> loss = torch.nn.CrossEntropyLoss (reduction = 'none')
>>> kl_output_unr = kl_loss (input = F.log_softmax (i_tensor_before_softmax, dim=-1), target = o_tensor)
>>> cross_ent_unr = loss (input = i_tensor_before_softmax.view (-1,4), target = o_tensor)
>>> ent_unr = loss (input=o_tensor_before_softmax.view (-1,4), target = o_tensor)
>>> kl_output_using_ce_unr = cross_ent_unr - ent_unr
>>> print (kl_output_unr)
tensor([[ 0.2151, -0.0271, -0.0686, -0.0478]])
>>> print (kl_output_using_ce_unr)
tensor([0.0716])
>>> print (kl_output_unr.mean(), kl_output_unr.sum(), kl_output_using_ce_unr)
tensor(0.0179) tensor(0.0716) tensor([0.0716])
```
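
As an aside, if you want `KLDivLoss` to line up directly with the cross-entropy difference, `reduction = 'batchmean'` sums over the class dimension and divides only by `nBatch` (the documentation also notes that `'batchmean'` is the reduction that matches the mathematical definition of the KL-divergence). Here is a minimal sketch along the lines of the example above, using the same made-up tensor values:

```
import torch
import torch.nn.functional as F

# made-up logits and a probabilistic target with shape [nBatch = 1, nClass = 4]
i_logits = torch.tensor ([[0.1, 0.2, 0.4, 0.3]])
o_logits = torch.tensor ([[0.7, 0.1, 0.1, 0.1]])
target = F.softmax (o_logits, dim = 1)

# 'batchmean' sums the pointwise KL over nClass and divides only by nBatch (here 1)
kl = torch.nn.KLDivLoss (reduction = 'batchmean') (F.log_softmax (i_logits, dim = 1), target)

# cross-entropy minus entropy, both computed with probabilistic targets
cross_ent = torch.nn.CrossEntropyLoss() (i_logits, target)
ent = torch.nn.CrossEntropyLoss() (o_logits, target)

print (kl, cross_ent - ent)   # both should come out to (approximately) 0.0716
```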

Best.

K. Frank