Numerical stability of BCEWithLogitsLoss

What are the differences, with respect to numerical stability, between the implementation of binary cross entropy with logits loss in PyTorch, as defined here

max_val = (-input).clamp(min=0)
loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
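
For reference, here is a minimal runnable sketch of what those two lines compute, assuming input holds raw logits and target holds labels in [0, 1] (the function name is mine, just for illustration):

import torch

def bce_with_logits_pytorch_style(input, target):
    # Clamp -input at zero so that both exp() arguments below are <= 0,
    # which rules out overflow for large-magnitude logits.
    max_val = (-input).clamp(min=0)
    return input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log()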

versus the formulation used in TensorFlow, defined here

max(x, 0) - x * z + log(1 + exp(-abs(x)))
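
As a quick sanity check, the TensorFlow-style expression can be written out in PyTorch and compared against the naive formula and the built-in torch.nn.functional.binary_cross_entropy_with_logits for extreme logits (a sketch, not either library's actual code; the helper names are mine):

import torch
import torch.nn.functional as F

def bce_with_logits_naive(x, z):
    # Direct formula: can produce inf once sigmoid(x) saturates to exactly 0 or 1.
    p = torch.sigmoid(x)
    return -(z * p.log() + (1 - z) * (1 - p).log())

def bce_with_logits_tf_style(x, z):
    # max(x, 0) - x * z + log(1 + exp(-abs(x))); -abs(x) <= 0, so exp() cannot overflow.
    return x.clamp(min=0) - x * z + (1 + (-x.abs()).exp()).log()

x = torch.tensor([-100.0, 0.0, 100.0])
z = torch.tensor([1.0, 1.0, 0.0])
print(bce_with_logits_naive(x, z))       # blows up where sigmoid saturates
print(bce_with_logits_tf_style(x, z))    # stays finite for all three logits
print(F.binary_cross_entropy_with_logits(x, z, reduction='none'))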

Are there any pros and cons of choosing one formulation over the other?

The first formulation is a two-pass log-sum-exp formulation: it first finds the largest exponent (max_val) and then factors it out, which is generally known to be more stable.
This page might help you understand it better: https://en.wikipedia.org/wiki/LogSumExp
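
For concreteness, the two-pass trick looks roughly like this (a sketch of the general log-sum-exp idea, not PyTorch's actual kernel):

import torch

def logsumexp_two_pass(a, b):
    # Pass 1: find the largest exponent.
    m = torch.maximum(a, b)
    # Pass 2: factor it out so every exp() argument is <= 0.
    return m + ((a - m).exp() + (b - m).exp()).log()

# log(1 + exp(-x)) written as logsumexp(0, -x): stays finite even for x = -1000.
x = torch.tensor([-1000.0, 0.0, 1000.0])
print(logsumexp_two_pass(torch.zeros_like(x), -x))   # ~[1000., 0.6931, 0.]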
