I get a minimal non-zero discrepancy of the order of 1e-7, which then occurs in the gradients and can affect the learning outcome.
How can I fix this problem?

I haven't traced through your code in detail, nor looked at
your GitHub issue, but the explanation is likely the following:

Your logits tensor is, by default (unless you change the
global default), of dtype = torch.float32. 32-bit floating
point numbers have about 7 decimal digits of precision.
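You can see this precision limit directly by querying the machine epsilon (the spacing between 1.0 and the next representable float). A quick illustration using NumPy, since this is a property of IEEE-754 floats rather than of PyTorch specifically:

```python
import numpy as np

# Machine epsilon: the gap between 1.0 and the next representable value.
# For float32 this is about 1.19e-7, i.e. roughly 7 decimal digits.
print(np.finfo(np.float32).eps)
# For float64 it is about 2.22e-16, i.e. roughly 16 decimal digits.
print(np.finfo(np.float64).eps)
```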

Even though (presumably, I haven't checked your code) torch_ce and custom_ce should be mathematically
equivalent, they likely do not perform their calculations in
exactly the same order, so floating-point round-off error
can and will lead to slight differences.
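As a sketch of how the order of operations matters, here is a NumPy comparison of two mathematically equivalent ways of computing log-softmax, a naive version and the usual log-sum-exp formulation (stand-ins for your two implementations, whose internals I haven't seen):

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 10)).astype(np.float32)

# Naive log-softmax: exponentiate, normalize, then take the log.
p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
naive = np.log(p)

# Log-softmax via the log-sum-exp trick: subtract the max first.
m = logits.max(axis=1, keepdims=True)
lse = m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True))
stable = logits - lse

# Mathematically identical, but the float32 results differ slightly.
print(np.abs(naive - stable).max())
```

The maximum difference is typically of order 1e-7, exactly the magnitude you are seeing.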

Your logits (randn) are of order one, so you should
reasonably expect round-off-error differences of
order 1e-7, which is what you see.

There isn't any "fix" for this; this is how floating-point
arithmetic works.

If you were to go through torch_ce and modify your custom_ce so that the two perform their calculations
in exactly the same order, you ought to be able to get
them to produce exactly the same result. But there
would be no point in that: your result would still differ
from the "true" mathematical result by a round-off error
on the order of 1e-7.

If for some reason you need greater precision, you could
use 64-bit floating point arithmetic. Try adding dtype = torch.float64 to your torch.randn() call.
If my explanation above is correct, you should then see
differences of the order of 1e-15.
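To illustrate the improvement, here is the same kind of order-of-operations discrepancy computed in both precisions (again using NumPy for a self-contained sketch; the dtype names are NumPy's, but torch.float32/torch.float64 behave the same way):

```python
import numpy as np

rng = np.random.default_rng(0)
diffs = {}
for dtype in (np.float32, np.float64):
    x = rng.standard_normal(1000).astype(dtype)
    # Two mathematically equivalent ways of computing the mean:
    a = x.sum() / len(x)        # sum first, then divide
    b = (x / len(x)).sum()      # divide first, then sum
    diffs[np.dtype(dtype).name] = abs(float(a) - float(b))
print(diffs)
```

The float64 discrepancy is many orders of magnitude smaller than the float32 one, at the cost of doubled memory and (on most GPUs) substantially slower arithmetic.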