CCE loss mean reduction differs between sparse and dense labels

Hello,
I am currently experimenting with class weights in the CCE loss and noticed that I get different results with the reduction “mean” depending on whether I use one-hot encoded targets or a sparse label tensor.
It seems to me that the class weights are not applied in the case of mean reduction with label targets.
Please check the code snippet below, where I would expect both “mean” results to be the same, which is not the case.
Am I misunderstanding something here?

import torch

weight=torch.tensor([0.2,0.8])
softmax = torch.nn.Softmax(-1)


cce_none = torch.nn.CrossEntropyLoss(weight=weight,reduction="none")
cce_mean = torch.nn.CrossEntropyLoss(weight=weight,reduction="mean")

input = torch.tensor([[.1,.6],[.3,.4]])
target = torch.tensor([[1.0,0.0],[1.0,0.0]])
sparse_target = torch.tensor([0,0])


print(input.shape)
print(target.shape)


print("\n log ")
# print(torch.log(input))
soft = softmax(input)
print(soft)
print("softlog",torch.log(soft))
print("softlog*weight",torch.log(soft)*weight)

cce_loss = cce_none(input,target)
print("none",cce_loss)

cce_loss = cce_mean(input,target)
print("mean",cce_loss)


print("\n sparse")
print(sparse_target.shape)

Which gives me:

torch.Size([2, 2])
torch.Size([2, 2])

 log 
tensor([[0.3775, 0.6225],
        [0.4750, 0.5250]])
softlog tensor([[-0.9741, -0.4741],
        [-0.7444, -0.6444]])
softlog*weight tensor([[-0.1948, -0.3793],
        [-0.1489, -0.5155]])
none tensor([0.1948, 0.1489])
mean tensor(0.1718)

 sparse
torch.Size([2])
none tensor([0.1948, 0.1489])
mean tensor(0.8592)

Hi Joe!

This is not exactly true – both cases do use the class weights. (Your basic
observation that the two approaches give different results is correct, though.)

The missing piece is that the two approaches compute the weighted mean
in reduction = 'mean' somewhat differently.

From the CrossEntropyLoss documentation, the “class-index” mean reduction
computes what I consider to be a proper “weighted mean” over the per-sample
(per-batch-element) losses. Namely, it sums the weights times the unweighted
losses and divides by the sum of the weights.

However, the “probability” mean reduction sums the weights times the
unweighted losses, but divides, instead, by the number of samples.
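
Written out (my paraphrase of the documented formulas), with class-index
targets y, batch size N, and unweighted per-class losses
l[n, c] = -log (softmax (x[n])[c]):

class-index mean:   sum_n ( weight[y[n]] * l[n, y[n]] )  /  sum_n ( weight[y[n]] )

probability mean:   sum_n ( sum_c ( weight[c] * target[n, c] * l[n, c] ) )  /  N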

Your two values differ (for the example you give) by a factor of 0.2, which is
weight[0]: both of your samples are of class 0, so the class-index denominator is
2 * weight[0] = 0.4, while the probability denominator is the batch size 2, and
0.4 / 2 = 0.2. (I agree that this is an unexpected difference, but it is consistent
with CrossEntropyLoss’s documentation.)

Here is a tweaked version of your script:

import torch
print (torch.__version__)

weight = torch.tensor ([0.2, 0.8])
input = torch.tensor ([[0.1, 0.6], [0.3, 0.4]])
target = torch.tensor ([[1.0, 0.0], [1.0, 0.0]])
sparse_target = torch.tensor ([0, 0])

cce_mean = torch.nn.CrossEntropyLoss (weight = weight, reduction = 'mean')

cce_loss_target = cce_mean (input, target)
cce_loss_sparse = cce_mean (input, sparse_target)
print ('cce_loss_target: ', cce_loss_target)
print ('cce_loss_sparse: ', cce_loss_sparse)

print ('cce_loss_target / weight[0]: ', cce_loss_target / weight[0])

And here is its output:

2.0.1
cce_loss_target:  tensor(0.1718)
cce_loss_sparse:  tensor(0.8592)
cce_loss_target / weight[0]:  tensor(0.8592)
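
To make the two denominators explicit, here is a small sketch (not part of the
original scripts) that reproduces both reduced values by hand from the
unreduced losses:

import torch

weight = torch.tensor([0.2, 0.8])
input = torch.tensor([[0.1, 0.6], [0.3, 0.4]])
sparse_target = torch.tensor([0, 0])

# unreduced losses already include the per-sample weights
cce_none = torch.nn.CrossEntropyLoss(weight=weight, reduction='none')
per_sample = cce_none(input, sparse_target)        # tensor([0.1948, 0.1489])

# "class-index" mean: divide by the sum of the selected weights
selected_weights = weight[sparse_target]           # tensor([0.2000, 0.2000])
print(per_sample.sum() / selected_weights.sum())   # tensor(0.8592)

# "probability" mean: divide by the number of samples instead
print(per_sample.sum() / per_sample.numel())       # tensor(0.1718)

For a one-hot probabilistic target the per-sample losses are identical to the
class-index ones, so only the denominator changes between the two reductions.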

It is tempting to require that the two weighted reductions give the same
results. However, your example is a special case in that your probabilistic
target is either exactly 0.0 or 1.0. There is a legitimate question of how
best to define the weighted reduction for a non-trivial probabilistic target
(such as [0.25, 0.75]).
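
As a rough illustration of the ambiguity (a sketch with made-up numbers, not
from your example), with a genuinely soft target each per-sample loss already
mixes both class weights, so there is no single obvious per-sample weight to
divide by:

import torch

weight = torch.tensor([0.2, 0.8])
input = torch.tensor([[0.1, 0.6], [0.3, 0.4]])
soft_target = torch.tensor([[0.25, 0.75], [0.25, 0.75]])

cce_none = torch.nn.CrossEntropyLoss(weight=weight, reduction='none')
# each element is -(0.25 * weight[0] * log p0  +  0.75 * weight[1] * log p1),
# so both weights contribute to every per-sample loss
print(cce_none(input, soft_target))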

This is discussed in GitHub issues 61309 and 61044. (But I agree with you
that the current choice of definition is not ideal.)

Best.

K. Frank

Thank you for the detailed explanation.
Now I understand why the difference is happening and how to handle it.