I’m trying to implement a CrossEntropyLoss layer that reproduces the behavior of the standard torch.nn.CrossEntropyLoss. I currently get the same loss values as nn.CrossEntropyLoss when I don’t aggregate the loss, but as soon as I do aggregate it, my result starts to diverge from nn.CrossEntropyLoss. Can anyone tell me how to fix my loss aggregation to match the PyTorch implementation? Here’s my code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyCrossEntropyLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input_, target):
        # Some code that I don't have questions about:
        ...
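        # NOTE: the following lines are a minimal, hypothetical sketch of how
        # wt and logpt could be computed, added only so this snippet runs end
        # to end; the actual (elided) code is what experiment 1 below validates.
        logp = F.log_softmax(input_, dim=1)            # log-probabilities: (mb, C, d1, ..., dk)
        mask = target != self.ignore_index             # True where the target counts
        safe_target = target.clamp(min=0)              # placeholder index for ignored entries
        logpt = logp.gather(1, safe_target.unsqueeze(1)).squeeze(1)  # log p of the true class
        wt = self.weight[safe_target] if self.weight is not None else torch.ones_like(logpt)
        wt = wt * mask                                 # ignored entries get zero weight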
        # Here's the problem:
        loss = -wt * logpt  # (A): shape (mb, d1, d2, ..., dk)
        if self.reduction == 'mean':
            # Weighted mean: divide by the total weight of the contributing targets
            return torch.sum(loss) / torch.sum(wt)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            # No aggregation, just return the raw values
            return loss
# Simulate a semantic segmentation minibatch: 8 images, 32 classes, 128x128 pixels
logits = torch.rand(size=(8, 32, 128, 128))
weights = torch.rand(32)
# Targets in [-1, 31], where -1 will be the ignore_index
truth = torch.randint(low=-1, high=32, size=(8, 128, 128))
# Experiment 1:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='none')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='none')
print(torch.equal(my_cel(logits, truth), cel(logits, truth))) # True
# Experiment 2:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='sum')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='sum')
my_loss = my_cel(logits, truth)
official_loss = cel(logits, truth)
print(torch.equal(my_loss, official_loss), my_loss, official_loss)
# False tensor(269083.4375) tensor(269083.1562)
# Experiment 3:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')
my_loss = my_cel(logits, truth)
official_loss = cel(logits, truth)
print(torch.equal(my_loss, official_loss), my_loss, official_loss)
# False tensor(3.5066) tensor(3.5072)
We can tell from experiment 1 that the line marked (A) computes the correct weighted losses. Note that torch.equal checks exact equality, so at this point I’m matching the reference implementation to full precision. This means that the variables wt and logpt are almost certainly correct as well, which is good to establish because wt is part of the mean calculation later.
I’m sort of OK with the result of experiment 2: the sum is only off by about one part per million, but I’d like to match the official implementation exactly if possible.
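For what it’s worth, a quick sanity check (my own aside, not part of the implementation) suggests that a float32 discrepancy of this size can come from summation order alone, so the two sums may simply accumulate terms in a different order:

x = torch.rand(8 * 128 * 128)  # same element count as the per-pixel loss
print(torch.equal(x.sum(), x.flip(0).sum()))   # usually False
print(x.sum().item(), x.flip(0).sum().item())  # agree to roughly 6-7 significant digits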
Experiment 3, however, shows that my mean aggregation is just incorrect. The PyTorch NLLLoss documentation describes how this aggregation is supposed to happen, but as far as I can tell my implementation matches it, so I’m at a loss as to how to fix it.
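For reference, the reduction='mean' formula in the NLLLoss docs, as I read it, is

$$\ell(x, y) = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} w_{y_n} \cdot \mathbb{1}\{y_n \neq \text{ignore\_index}\}}, \qquad l_n = -w_{y_n}\, x_{n, y_n},$$

and since my wt is already zero wherever the target equals ignore_index, torch.sum(loss) / torch.sum(wt) should be exactly this quantity.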
Thanks in advance for your help.