How does nn.CrossEntropyLoss aggregate the loss?

I’m trying to implement a CrossEntropyLoss layer that reproduces the behavior of the standard torch.nn.CrossEntropyLoss. Currently I get the same loss values as nn.CrossEntropyLoss when I don’t aggregate the loss, but as soon as I do aggregate it the result starts to diverge from nn.CrossEntropyLoss. Can anyone tell me how to fix my loss aggregation to match the PyTorch implementation? Here’s my code.

import torch
import torch.nn as nn

class MyCrossEntropyLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction
        
    def forward(self, input_, target):
        # Some code that I don't have questions about:
        ...
        # Here's the problem:
        loss = - wt * logpt    # (A)  mb, d1, d2, ..., dk
        if self.reduction == 'mean':
            return torch.sum(loss) / torch.sum(wt)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            # No aggregation, just return the raw values
            return loss

# Simulate a semantic segmentation minibatch with 8 images, 32 classes and 128x128 pixels
logits = torch.rand(size=(8, 32, 128, 128))
weights = torch.rand(32)
truth = torch.randint(-1, 32, (8, 128, 128))  # labels in [-1, 31]; -1 marks ignored pixels

# Experiment 1:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='none')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='none')
print(torch.equal(my_cel(logits, truth), cel(logits, truth))) # True

# Experiment 2:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='sum')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='sum')

my_loss = my_cel(logits, truth)
official_loss = cel(logits, truth)
print(torch.equal(my_loss, official_loss), my_loss, official_loss)
# False tensor(269083.4375) tensor(269083.1562)

# Experiment 3:
my_cel = MyCrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')
cel = nn.CrossEntropyLoss(weight=weights, ignore_index=-1, reduction='mean')

my_loss = my_cel(logits, truth)
official_loss = cel(logits, truth)
print(torch.equal(my_loss, official_loss), my_loss, official_loss)
# False tensor(3.5066) tensor(3.5072)

We can tell from experiment 1 that the line marked (A) computes the correct weighted losses. Note that torch.equal checks exact element-wise equality, so at this point I’m matching the reference implementation to full precision. This means the variables wt and logpt are almost certainly correct as well, which is good to establish because wt is also part of the mean calculation later.
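For reference, torch.equal requires an exact element-wise match, whereas torch.allclose only requires agreement within a tolerance; here is a tiny illustration with made-up values:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-6                  # perturb every element slightly

print(torch.equal(a, b))      # False: exact element-wise equality is required
print(torch.allclose(a, b))   # True: equal within the default tolerances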

I’m more or less OK with the result of experiment 2: the sum is only off by roughly one part per million, but I’d still like to match the official implementation exactly if possible.
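For what it’s worth, a relative difference of that size is roughly what float32 accumulation error looks like when summing ~131k values. This is just an illustration with random numbers (not the actual losses), comparing a float32 sum against a float64 reference:

import torch

torch.manual_seed(0)
vals = torch.rand(8 * 128 * 128) * 4    # ~131k values, roughly the magnitude of the per-pixel losses

s32 = vals.sum()                        # accumulated in float32
s64 = vals.double().sum()               # accumulated in float64 as a reference

print(s32.item(), s64.item())           # typically agree to only about 6-7 significant digits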

Experiment 3, however, shows that my mean aggregation is genuinely incorrect: the relative error is roughly two orders of magnitude larger than in experiment 2. The nn.NLLLoss documentation describes how this aggregation is supposed to be done, and as far as I can tell my implementation matches it, so I’m at a loss as to how to fix it.
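For reference, this is my reading of the weighted 'mean' reduction from the nn.NLLLoss docs, checked against nn.CrossEntropyLoss on a tiny made-up example (all the shapes and numbers below are just for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 3)                   # 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])
weight = torch.tensor([0.2, 1.0, 3.0])

logp = F.log_softmax(logits, dim=1)
wt = weight[target]                          # weight of each sample's true class
per_sample = -wt * logp[torch.arange(4), target]

manual_mean = per_sample.sum() / wt.sum()    # weighted sum divided by the sum of the true-label weights
official = nn.CrossEntropyLoss(weight=weight, reduction='mean')(logits, target)

print(manual_mean.item(), official.item())   # agree up to float32 rounding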

Thanks in advance for your help.

Your reductions don’t seem to use the passed weight tensor.
Have a look at this post and let me know if it solves the issue.

Sorry if this wasn’t clear. The wt variable holds the weight for each sample’s true label. I use it both in computing the raw loss and in the mean aggregation.

nn.CrossEntropyLoss would also apply the weight for the current true label, but would additionally normalize with it as shown in my code snippet, wouldn’t it?

Yes. We’re both normalizing by dividing by the sum of the weights for the true labels.
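A quick way to confirm this is that, with weights, the official 'mean' result equals the official 'sum' result divided by the summed true-label weights. A small made-up example (ignore_index omitted for brevity):

import torch
import torch.nn as nn

logits = torch.randn(5, 3)
target = torch.tensor([0, 1, 2, 1, 0])
weight = torch.tensor([0.5, 1.0, 2.0])

mean_loss = nn.CrossEntropyLoss(weight=weight, reduction='mean')(logits, target)
sum_loss = nn.CrossEntropyLoss(weight=weight, reduction='sum')(logits, target)

print(mean_loss.item(), (sum_loss / weight[target].sum()).item())   # match up to float32 rounding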

Could you post the complete code then, please, as it’s unclear how wt is defined in the code, which makes debugging hard. :slight_smile:

This is the full code. Remember that we know the problem is in the aggregation, because the experiments show that the non-aggregated loss matches the output of nn.CrossEntropyLoss exactly.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyCrossEntropyLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction
        
    def forward(self, input_, target):
        ignored = target == self.ignore_index              # mb, d1, d2, ..., dk
        # Set the ignored labels to zero. We will later multiply these by zero
        # weights to ignore them. 
        target = target.clone()
        target[ignored] = 0
        # 1.0 where the target is ignored, 0.0 everywhere else
        ignored = ignored.float()
        logp = F.log_softmax(input_, dim=1)                # mb, C, d1, d2, ..., dk
        # Gather the predictions for the true labels
        logpt = torch.gather(logp, 1, target.unsqueeze(1)) # mb, 1, d1, d2, ..., dk
        logpt = logpt.squeeze(1)                           # mb, d1, d2, ..., dk
        if self.weight is not None:
            w = self.weight.expand(target.shape + self.weight.shape) # mb, d1, d2, ..., dk, C
            # Construct the permutation that will move the channels from the end to
            # index 1. There has got to be an easier way
            permutation = (0, -1) + tuple(range(1, len(w.shape)-1)) 
            w = w.permute(permutation)                               # mb, C, d1, d2, ..., dk
            # Gather the weights for the true labels.
            wt = torch.gather(w, 1, target.unsqueeze(1))             # mb, 1, d1, d2, ..., dk
            wt = wt.squeeze(1)                                       # mb, d1, d2, ..., dk
            wt *= 1 - ignored                                        # mb, d1, d2, ..., dk
        else:
            wt = 1 - ignored
        loss = - wt * logpt    # mb, d1, d2, ..., dk

        if self.reduction == 'mean':
            return torch.sum(loss / torch.sum(wt))
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            return loss