How to reproduce a Cross Entropy Loss without losing numerical accuracy?

I get a discrepancy between the loss values returned by torch.nn.CrossEntropyLoss and my CustomCrossEntropyLoss.

Issue on GitHub

Code to Reproduce

import torch
import torch.nn.modules.loss as L

class CustomCrossEntropyLoss(L._Loss):
    def __init__(self, reduction=True):
        super(CustomCrossEntropyLoss, self).__init__()
        self.reduction = reduction

    def forward(self, inp, target):
        # pick out the logit of the target class for each sample
        input_target = inp.gather(1, target.view(-1, 1))
        # log-sum-exp trick: subtract the per-row max before exponentiating
        input_max, _ = inp.max(dim=1, keepdim=True)
        output_exp = torch.exp(inp - input_max)
        output_softmax_sum = output_exp.sum(dim=1)
        output = -input_target + torch.log(output_softmax_sum).view(-1, 1) + input_max
        if self.reduction:
            output = output.mean()
        return output

torch_ce = torch.nn.CrossEntropyLoss(reduction='none')
custom_ce = CustomCrossEntropyLoss(reduction=False)
batch_size = 128
N_class = 90000
logits = torch.randn((batch_size, N_class))
targets = torch.randint(N_class, (batch_size,))
print((torch_ce(logits, targets).view(-1) - custom_ce(logits, targets).view(-1)).mean())

I get a small non-zero discrepancy on the order of 1e-7, which also shows up in the gradients and can affect the training outcome.
How can I fix this problem?

Hi Tamerlan!

I haven’t traced through your code in detail, nor looked at your GitHub issue, but the explanation is likely the following:

Your logits tensor is, by default (unless you change the global default), of dtype = torch.float32. 32-bit floating-point numbers have about 7 decimal digits of precision.
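
As a quick sanity check (this snippet is mine, not part of the original post), you can ask PyTorch for the machine epsilon of each dtype, which is the relative spacing of representable values:

import torch

# machine epsilon: gap between 1.0 and the next representable value
print(torch.finfo(torch.float32).eps)   # ~1.19e-07  (about 7 decimal digits)
print(torch.finfo(torch.float64).eps)   # ~2.22e-16  (about 15-16 decimal digits)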

Even though (presumably – I haven’t checked your code) torch_ce and custom_ce should be mathematically equivalent, they likely do not perform their calculations in exactly the same order, so floating-point round-off error can and will lead to slight differences.
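
To illustrate (my own example, not from your code): summing the same float32 numbers in two different orders is mathematically the same operation, but it usually does not give bit-identical results:

import torch

torch.manual_seed(0)
x = torch.randn(90000)              # float32 values of order one

s1 = x.sum()                        # one summation order
s2 = x.sort().values.sum()          # the same numbers, summed in sorted order
s_ref = x.double().sum()            # float64 reference value

print((s1 - s2).item())             # usually small but non-zero
print((s1.double() - s_ref).item()) # round-off relative to the float64 result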

Your logits (randn) are of order one, so you should reasonably expect round-off-error differences of order 1e-7.

There isn’t any “fix” for this – this is how floating-point arithmetic works.

If you were to go through torch_ce and modify your custom_ce so that the two perform their calculations in exactly the same order, you ought to be able to get them to produce exactly the same result. But there would be no point in that – your result would still differ from the “true” mathematical result by a round-off error on the order of 1e-7.

If for some reason you need greater precision, you could use 64-bit floating-point arithmetic. Try adding dtype = torch.float64 to your torch.randn() call. If my explanation above is correct, you should now see differences on the order of 1e-15.
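
Concretely, keeping the class definition from your post, the only change is the dtype of logits (a sketch, with the variable names taken from your reproduction script):

import torch

batch_size = 128
N_class = 90000

# float64 logits: roughly 15-16 significant decimal digits instead of ~7
logits = torch.randn((batch_size, N_class), dtype=torch.float64)
targets = torch.randint(N_class, (batch_size,))

torch_ce = torch.nn.CrossEntropyLoss(reduction='none')
custom_ce = CustomCrossEntropyLoss(reduction=False)   # class from the question above

# the mean difference should now be around 1e-15 instead of 1e-7
print((torch_ce(logits, targets).view(-1) - custom_ce(logits, targets).view(-1)).mean())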

Good luck.

K. Frank
