# How does nn.CrossEntropyLoss aggregate the loss?

I’m trying to implement a cross-entropy loss layer that reproduces the behavior of the standard `torch.nn.CrossEntropyLoss`. With `reduction='none'` I get exactly the same loss values as `nn.CrossEntropyLoss`, but as soon as I aggregate the loss, my results diverge from the reference. Can anyone tell me how to fix my loss aggregation to match the PyTorch implementation? Here’s my code.

``````
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyCrossEntropyLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input_, target):
        # Some code that I don't have questions about:
        ...
        # Here's the problem:
        loss = -wt * logpt    # line A; shape: mb, d1, d2, ..., dk
        if self.reduction == 'mean':
            # Weighted mean: divide by the summed weights of the true labels
            loss = loss.sum() / wt.sum()
        elif self.reduction == 'sum':
            loss = loss.sum()
        else:
            # No aggregation, just return the raw values
            pass
        return loss

# Simulate a semantic segmentation minibatch with 8 images, 32 classes and 128x128 pixels
logits = torch.rand(size=(8, 32, 128, 128))
weights = torch.rand(32)
truth = torch.LongTensor(size=(8, 128, 128)).random_(-1, 32)

# Experiment 1:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='none')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='none')
print(torch.equal(my_cel(logits, truth), cel(logits, truth))) # True

# Experiment 2:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='sum')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='sum')

my_loss = my_cel(logits, truth)
official_loss = cel(logits, truth)
print(torch.equal(my_loss, official_loss), my_loss, official_loss)
# False tensor(269083.4375) tensor(269083.1562)

# Experiment 3:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')

my_loss = my_cel(logits, truth)
official_loss = cel(logits, truth)
print(torch.equal(my_loss, official_loss), my_loss, official_loss)
# False tensor(3.5066) tensor(3.5072)
``````

We can tell from experiment 1 that line A computes the correct weighted losses. Note that `torch.equal` checks exact equality, so at this point I’m matching the reference implementation element for element, with no tolerance. This means that the variables `wt` and `logpt` are almost certainly correct as well, which is good to establish because `wt` is part of the mean calculation later.
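To illustrate the difference between an exact and a tolerance-based comparison, here is a minimal sketch. The tensors `a` and `b` are made up for illustration and are not taken from the experiments above. `torch.allclose` accepts two tensors when `|a - b| <= atol + rtol * |b|` holds elementwise, with defaults `rtol=1e-05` and `atol=1e-08`.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a * (1 + 1e-6)  # hypothetical perturbation of about one part per million

exact = torch.equal(a, b)      # True only if sizes and all elements are identical
close = torch.allclose(a, b)   # True if every element is within the tolerance
print(exact, close)            # False True
```

A `True` from `torch.equal`, as in experiment 1, is therefore a much stronger statement than agreement within a tolerance.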

I’m sort of OK with the results of experiment 2: the sum is only off by about one part per million (0.28 on a total of roughly 269083), but I’d like to fix it to match the official implementation if possible.
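A small difference like this is what one would expect if two float32 reductions add the same terms in a different order. The sketch below is not the code from this thread: the values `big` and `one` are made up to show that float32 addition is not associative, and the spacing computation assumes the standard 24-bit float32 significand. It suggests, but does not prove, that the gap in experiment 2 is accumulated rounding error.

```python
import math
import torch

# Float32 addition is not associative: grouping the terms differently changes the result.
big = torch.tensor(1e8, dtype=torch.float32)
one = torch.tensor(1.0, dtype=torch.float32)
left_to_right = (big + one) - big   # the 1.0 is rounded away when added to 1e8
regrouped = (big - big) + one       # the 1.0 survives
print(left_to_right.item(), regrouped.item())

# Spacing between adjacent float32 values near the experiment 2 result.
value = 269083.4375
ulp = 2.0 ** (math.floor(math.log2(value)) - 23)
print(ulp)

# The two printed sums from experiment 2, measured in float32 steps.
gap_in_ulps = (269083.4375 - 269083.15625) / ulp
print(gap_in_ulps)
```

Near 2.7e5 adjacent float32 values are 0.03125 apart, so the two sums in experiment 2 differ by about nine representable steps.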

Experiment 3, however, shows that my mean aggregation is just incorrect. The PyTorch `NLLLoss` documentation describes how this aggregation is supposed to happen, but as far as I can tell my implementation matches it, so I’m at a loss as to how to fix it.
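For reference, the weighted mean described in the `NLLLoss` documentation can be paraphrased as follows. The notation is an assumption chosen for this note: $x_{n,y_n}$ is the log-probability of the true class $y_n$ at position $n$, $w_c$ is the weight of class $c$, and $\mathbb{1}\{\cdot\}$ is 1 for positions that are not ignored and 0 otherwise.

$$
l_n = -\,w_{y_n}\, x_{n,y_n}\, \mathbb{1}\{y_n \neq \text{ignore\_index}\},
\qquad
\ell_{\text{mean}} = \frac{\sum_{n} l_n}{\sum_{n} w_{y_n}\, \mathbb{1}\{y_n \neq \text{ignore\_index}\}}
$$

With `reduction='sum'` only the numerator $\sum_n l_n$ is returned.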

Your reductions don’t seem to use the passed `weight` tensor.
Have a look at this post and let me know if this would solve the issue.

Sorry if this wasn’t clear. The `wt` variable is the weight for the true label. I use it in both computing the raw loss and the mean aggregation.

`nn.CrossEntropyLoss` would also apply the weight for the current true label, but would additionally normalize with it as shown in my code snippet, wouldn’t it?

Yes. We’re both normalizing by dividing by the sum of the weights for the true labels.
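As a sanity check on that normalization, here is a minimal sketch on a small made-up batch in double precision. The sizes, the seed, and the names `keep`, `safe_truth` and `manual_mean` are assumptions for illustration and not the code from this thread. It computes the sum of the weighted losses divided by the sum of the true-label weights over the non-ignored positions, and compares the result with `nn.CrossEntropyLoss` using a tolerance.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.rand(4, 5, 6, 6, dtype=torch.float64)   # mb, C, d1, d2
weights = torch.rand(5, dtype=torch.float64)            # one weight per class
truth = torch.randint(-1, 5, (4, 6, 6))                 # -1 marks ignored positions

keep = truth != -1                                      # mask of positions that count
safe_truth = truth.clamp(min=0)                         # placeholder class for ignored positions
logp = F.log_softmax(logits, dim=1)
logpt = logp.gather(1, safe_truth.unsqueeze(1)).squeeze(1)
wt = weights[safe_truth] * keep                         # true-label weight, zero where ignored

# Weighted mean: summed weighted losses over summed true-label weights.
manual_mean = (-wt * logpt).sum() / wt.sum()
reference = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')(logits, truth)
print(torch.allclose(manual_mean, reference))
```

Running the comparison in float64 keeps rounding error far below the default tolerance, which separates the question of whether the formula is right from the question of float32 accumulation.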

Could you post the complete code then, please? It’s unclear how `wt` is defined, which makes debugging hard.

This is the full code. Remember that we know the problem is in the aggregation, because the experiments show that the non-aggregated loss matches the output of `nn.CrossEntropyLoss` exactly.

``````
class MyCrossEntropyLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input_, target):
        ignored = target == self.ignore_index              # mb, d1, d2, ..., dk
        # Set the ignored labels to zero. We will later multiply these by zero
        # weights to ignore them.
        target = target.clone()
        target[ignored] = 0
        ignored = ignored.type(torch.FloatTensor)
        logp = F.log_softmax(input_, dim=1)                # mb, C, d1, d2, ..., dk
        # Gather the predictions for the true labels
        logpt = torch.gather(logp, 1, target.unsqueeze(1)) # mb, 1, d1, d2, ..., dk
        logpt = logpt.squeeze(1)                           # mb, d1, d2, ..., dk
        if self.weight is not None:
            w = self.weight.expand(target.shape + self.weight.shape) # mb, d1, d2, ..., dk, C
            # Construct the permutation that will move the channels from the end to
            # index 1. There has got to be an easier way
            permutation = (0, -1) + tuple(range(1, len(w.shape)-1))
            w = w.permute(permutation)                               # mb, C, d1, d2, ..., dk
            # Gather the weights for the true labels.
            wt = torch.gather(w, 1, target.unsqueeze(1))             # mb, 1, d1, d2, ..., dk
            wt = wt.squeeze(1)                                       # mb, d1, d2, ..., dk
            wt *= 1 - ignored                                        # mb, d1, d2, ..., dk
        else:
            wt = 1 - ignored
        loss = - wt * logpt    # mb, d1, d2, ..., dk

        if self.reduction == 'mean':