 # How to reproduce a Cross Entropy Loss without losing numerical accuracy?

I get a discrepancy between the loss values obtained by `torch.nn.CrossEntropyLoss` and my `CustomCrossEntropyLoss`.

Issue on github

Code to Reproduce

```python
import torch
import torch.nn.modules.loss as L

class CustomCrossEntropyLoss(L._Loss):
    def __init__(self, reduction=True):
        super(CustomCrossEntropyLoss, self).__init__()
        self.reduction = reduction

    def forward(self, inp, target):
        # pick out the logit of the target class for each sample
        input_target = inp.gather(1, target.view(-1, 1))
        # subtract the per-row max for numerical stability (log-sum-exp trick)
        input_max, _ = inp.max(dim=1, keepdim=True)
        output_exp = torch.exp(inp - input_max)
        output_softmax_sum = output_exp.sum(dim=1)
        output = -input_target + torch.log(output_softmax_sum).view(-1, 1) + input_max
        if self.reduction:
            output = output.mean()
        return output

torch_ce = torch.nn.CrossEntropyLoss(reduction='none')
custom_ce = CustomCrossEntropyLoss(reduction=False)
batch_size = 128
N_class = 90000
logits = torch.randn((batch_size, N_class))
targets = torch.randint(N_class, (batch_size,))
print((torch_ce(logits, targets).view(-1) - custom_ce(logits, targets).view(-1)).mean())
```

I get a small non-zero discrepancy on the order of 1e-7, which then propagates into the gradients and can affect the training outcome.
How can I fix this problem?

Hi Tamerlan!

I haven’t traced through your code in detail, nor looked at
your github issue, but the explanation is likely the following:

Your `logits` tensor is, by default (unless you change the
global default), of `dtype = torch.float32`. 32-bit floating
point numbers have about 7 decimal digits of precision.

Even though (presumably – I haven’t checked your code)
`torch_ce` and `custom_ce` should be mathematically
equivalent, they likely do not perform their calculations in
exactly the same order, so floating-point round-off error
can and will lead to slight differences.
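A minimal sketch of this effect (the variable names here are illustrative, not from the original code): summing the very same float32 numbers in two different orders already produces slightly different results, for the same reason the two loss implementations disagree.

```python
import torch

torch.manual_seed(0)
x = torch.randn(90000)  # float32 by default

# The same numbers summed in two different orders accumulate
# round-off error differently, so the results can disagree slightly.
s_forward = x.sum()
s_reverse = x.flip(0).sum()
diff = (s_forward - s_reverse).abs().item()
print(diff)
```

The difference is tiny relative to the magnitude of the sum, but it is generally not exactly zero.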

Your `logits` (`randn`) are of order one, so you should
reasonably expect the round-off-error differences you
see to be of order 1e-7.

There isn’t any “fix” for this – this is how floating-point
arithmetic works.

If you were to go through `torch_ce` and modify your
`custom_ce` so that the two perform their calculations
in exactly the same order, you ought to be able to get
them to produce exactly the same result. But there
would be no point in that – your result would still differ
from the “true” mathematical result by a round-off error
on the order of 1e-7.

If for some reason you need greater precision, you could
use 64-bit floating point arithmetic. Try adding
`dtype = torch.float64` to your `torch.randn()` call.
If my explanation above is correct, you should now see
differences on the order of 1e-15.
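For example, a sketch of the same comparison in float64 (for brevity the manual computation below uses `torch.logsumexp` instead of your custom class, but it is mathematically the same cross entropy):

```python
import torch

torch.manual_seed(0)
batch_size, n_class = 128, 90000
logits = torch.randn((batch_size, n_class), dtype=torch.float64)
targets = torch.randint(n_class, (batch_size,))

torch_ce = torch.nn.CrossEntropyLoss(reduction='none')
builtin = torch_ce(logits, targets)

# the same cross entropy written out by hand, also in float64
manual = -logits.gather(1, targets.view(-1, 1)).view(-1) + torch.logsumexp(logits, dim=1)

max_diff = (builtin - manual).abs().max().item()
print(max_diff)  # tiny compared with the float32 case
```

The two implementations still do their arithmetic in different orders, but with roughly 16 decimal digits of precision the residual disagreement becomes negligible for training purposes.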

Good luck.

K. Frank
