I’m trying to implement a CrossEntropyLoss layer that reproduces the behavior of the standard `torch.nn.CrossEntropyLoss`. Currently I get the same loss values as `nn.CrossEntropyLoss` when I don’t aggregate the loss, but as soon as I aggregate, the result starts to diverge from `nn.CrossEntropyLoss`. Can anyone tell me how to fix my loss aggregation to match the PyTorch implementation? Here’s my code.

```
import torch
import torch.nn as nn

class MyCrossEntropyLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input_, target):
        # Some code that I don't have questions about:
        ...
        # Here's the problem:
        loss = -wt * logpt  # line A; shape (mb, d1, d2, ..., dk)
        if self.reduction == 'mean':
            return torch.sum(loss) / torch.sum(wt)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            # No aggregation, just return the raw values
            return loss

# Simulate a semantic segmentation minibatch with 8 images, 32 classes
# and 128x128 pixels; -1 marks ignored pixels.
logits = torch.rand(size=(8, 32, 128, 128))
weights = torch.rand(32)
truth = torch.randint(low=-1, high=32, size=(8, 128, 128))

# Experiment 1:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='none')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='none')
print(torch.equal(my_cel(logits, truth), cel(logits, truth)))  # True

# Experiment 2:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='sum')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='sum')
my_loss = my_cel(logits, truth)
official_loss = cel(logits, truth)
print(torch.equal(my_loss, official_loss), my_loss, official_loss)
# False tensor(269083.4375) tensor(269083.1562)

# Experiment 3:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')
my_loss = my_cel(logits, truth)
official_loss = cel(logits, truth)
print(torch.equal(my_loss, official_loss), my_loss, official_loss)
# False tensor(3.5066) tensor(3.5072)
```

We can tell from experiment 1 that line A computes the correct weighted losses. Note that `torch.equal` checks exact equality, so at this point I’m matching the reference implementation to full precision. This means that the variables `wt` and `logpt` are almost certainly correct as well, which is good to establish because `wt` is part of the mean calculation later.
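
For reference, this per-element loss is the one documented for `nn.CrossEntropyLoss` with class weights; in the docs’ notation, my `wt` holds the $w_{y_n}$ values (zeroed at ignored pixels) and my `logpt` is the log-softmax term:

$$
l_n = -\,w_{y_n}\,\log\frac{\exp(x_{n,y_n})}{\sum_{c=1}^{C}\exp(x_{n,c})},
\qquad w_{y_n} = 0 \ \text{when}\ y_n = \text{ignore\_index}
$$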

I’m sort of OK with the results of experiment 2: the sum is only off by about one part per million, which looks like float32 rounding, but I’d like to match the official implementation exactly if possible.
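
One check I can think of (just a sketch; it assumes the `forward` code I elided above also handles float64 inputs) is to repeat experiment 2 in double precision. If the gap is only float32 accumulation order, both sums should agree to many more digits:

```
# Sketch: repeat experiment 2 in float64 to test whether the float32
# discrepancy is just accumulation/rounding error.
logits64 = logits.double()
weights64 = weights.double()
my_cel64 = MyCrossEntropyLoss(weight=weights64, ignore_index=-1, reduction='sum')
cel64 = nn.CrossEntropyLoss(weight=weights64, ignore_index=-1, reduction='sum')
print(my_cel64(logits64, truth), cel64(logits64, truth))
```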

Experiment 3, however, shows that my mean aggregation is simply incorrect. The `nn.NLLLoss` documentation spells out how this aggregation is supposed to happen, and as far as I can tell my implementation matches it, so I’m at a loss as to how to fix it.
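
Concretely, the documented weighted ‘mean’ reduction is the sum of the per-element losses divided by the sum of the target-class weights over the non-ignored elements, which is exactly what my `torch.sum(loss) / torch.sum(wt)` computes (my `wt` is already zero at ignored pixels):

$$
\ell = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} w_{y_n}\,\mathbb{1}\{y_n \neq \text{ignore\_index}\}}
$$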

Thanks in advance for your help.