Binary Cross Entropy with logits does not work as expected

While tinkering with the official code example for Variational Autoencoders, I experienced some unexpected behaviour with regard to the Binary Cross-Entropy loss. When I use F.binary_cross_entropy in combination with the sigmoid function, the model trains as expected on MNIST. However, when changing to the F.binary_cross_entropy_with_logits function, the loss suddenly becomes arbitrarily small during training and the model no longer produces meaningful results.

import torch
import torch.nn.functional as F

# recon_x: decoder output (logits); x: batch of MNIST images

# For this loss function, the loss becomes arbitrarily small
BCE = F.binary_cross_entropy_with_logits(recon_x, x.view(-1, 784), reduction='sum') / x.shape[0]

# For this loss function, the training works as expected
BCE = F.binary_cross_entropy(torch.sigmoid(recon_x), x.view(-1, 784), reduction='sum') / x.shape[0]

To my understanding, the only difference between the two approaches should be numerical stability. Am I missing something?
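Writing σ(x) = 1/(1 + e^{-x}) for a logit x and y for the target, the two losses are algebraically the same function of x. The last line is the overflow-free form a logits-based implementation can evaluate, because it never takes exp() of a large positive number (PyTorch's exact internal formula may differ in detail):

```latex
\begin{aligned}
\ell(x, y) &= -\bigl[\, y \log \sigma(x) + (1 - y) \log\bigl(1 - \sigma(x)\bigr) \bigr] \\
           &= (1 - y)\, x + \log\bigl(1 + e^{-x}\bigr) \\
           &= \max(x, 0) - x\, y + \log\bigl(1 + e^{-|x|}\bigr)
\end{aligned}
```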

Hello Simon!

I agree with your understanding, and looking at the two lines of
code you posted, I don’t see anything suspicious (although I miss
things all the time …).

If I had to guess, I would guess that you have a typo somewhere
else in your code that causes the two runs to differ.

However, we can start by checking out some basics:

Here I run the example given in the documentation for
torch.nn.functional.binary_cross_entropy. (Please note
that I am running this test, for whatever reason, with pytorch 0.3.0.)

Here is the script:

import torch
print (torch.__version__)
torch.manual_seed (2019)
input =  torch.autograd.Variable (torch.randn ((3, 2)))
print (input)
target = torch.autograd.Variable (torch.rand ((3, 2)))
print (target)
loss_plain = torch.nn.functional.binary_cross_entropy (torch.sigmoid (input), target)
print (loss_plain)
loss_logits = torch.nn.functional.binary_cross_entropy_with_logits (input, target)
print (loss_logits)
print (loss_plain - loss_logits)

And here is the output:

>>> import torch
>>> print (torch.__version__)
0.3.0b0+591e73e
>>> torch.manual_seed (2019)
<torch._C.Generator object at 0x00000207759A60F0>
>>> input =  torch.autograd.Variable (torch.randn ((3, 2)))
>>> print (input)
Variable containing:
-0.1187  0.2110
 0.7463 -0.6136
-0.1186  1.5565
[torch.FloatTensor of size 3x2]

>>> target = torch.autograd.Variable (torch.rand ((3, 2)))
>>> print (target)
Variable containing:
 0.7628  0.0721
 0.2208  0.3979
 0.6338  0.1922
[torch.FloatTensor of size 3x2]

>>> loss_plain = torch.nn.functional.binary_cross_entropy (torch.sigmoid (input), target)
>>> print (loss_plain)
Variable containing:
 0.8868
[torch.FloatTensor of size 1]

>>> loss_logits = torch.nn.functional.binary_cross_entropy_with_logits (input, target)
>>> print (loss_logits)
Variable containing:
 0.8868
[torch.FloatTensor of size 1]

>>> print (loss_plain - loss_logits)
Variable containing:
1.00000e-08 *
 -5.9605
[torch.FloatTensor of size 1]

And, indeed, the two expressions are the same (up to floating-point
precision).

Is there any way you can capture a specific instance of your
recon_x and x.view() so that you can pump what you know
are the same values into your two cross-entropy expressions?

Is there any way you can track your loss (your BCE) on a step-by-step
basis so you can see when they first diverge?

The thing that is noteworthy to me is that you say that the less
numerically stable version (regular bce) works, while the more
stable version (bce_with_logits) quits working at some point.
This is backwards of what one might expect.

Perhaps there is something about your model that puts it on
the edge of being poorly behaved. In such a case it could be
plausible that, by happenstance, the bce version stays in a
well-behaved region, but small differences due to using the
bce_with_logits version cause it to drift into a poorly-behaved
region.

If this were the case I would expect that other perturbations
such as starting with different initial weights, or using a different
optimization algorithm or learning rate could cause the training
to “randomly” end up being well behaved or poorly behaved.

So, I would:

  1. proofread the code to make sure there isn’t some outright error
  2. check that your model and data are reasonably well behaved
    and stable with respect to perturbations of the details of your
    training
  3. track your bce and bce_with_logits step by step to find out where
    they first diverge, and drill down with the values that immediately
    precede the divergence.
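Step 3 could be sketched roughly as follows. The helper names bce_pair and diverged, the tolerances, and the dump file name are all hypothetical; the two loss expressions are copied from the question, including its assumption of 28x28 MNIST images:

```python
import torch
import torch.nn.functional as F

def bce_pair(recon_x, x):
    # Both losses exactly as in the question: summed, then divided by batch size.
    target = x.view(-1, 784)  # assumes 28x28 MNIST images, as in the question
    n = x.shape[0]
    with_logits = F.binary_cross_entropy_with_logits(recon_x, target, reduction="sum") / n
    plain = F.binary_cross_entropy(torch.sigmoid(recon_x), target, reduction="sum") / n
    return with_logits, plain

def diverged(with_logits, plain, rtol=1e-3, atol=1e-3):
    # Hypothetical tolerances: True once the two values disagree beyond float noise.
    return not torch.isclose(with_logits, plain, rtol=rtol, atol=atol).item()

# Hypothetical use inside the training loop, right after recon_x is computed:
#     bce_l, bce_p = bce_pair(recon_x, x)
#     if diverged(bce_l, bce_p):
#         torch.save({"recon_x": recon_x.detach(), "x": x.detach()}, f"diverged_{step}.pt")
#         break

torch.manual_seed(0)
recon = torch.randn(4, 784)
x_ok = torch.rand(4, 784)  # valid probabilities in [0, 1)
print(diverged(*bce_pair(recon, x_ok)))  # expected: False
```

Saving the tensors at the first divergent step gives exactly the "same values into both expressions" test case suggested above.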

Have fun!

K. Frank


Hello KFrank,

thank you for your elaborate response. I have uploaded an instance of tensors for which the two losses diverge here. This is the resulting output for it:

import torch

x_view  = torch.load("x_view.pt")
x_recon = torch.load("x_recon.pt")

BCE1 = torch.nn.functional.binary_cross_entropy_with_logits(x_recon, x_view, reduction='sum')
BCE2 = torch.nn.functional.binary_cross_entropy(torch.sigmoid(x_recon), x_view, reduction='sum')

print("BCE loss with logits: ", BCE1)  # -2.3662e+08
print("BCE loss with sigmoid: ", BCE2) # -379848.2500
print("Loss difference: ", BCE1-BCE2)  # -2.3624e+08

The x_recon values are all very small (around -1.0e+05). However, I am not sure why the version without logits behaves so differently in this regime. Do you have any insights in this regard?

In the meantime, I’ll try to figure out the minimal changes to the official example that are necessary to reproduce the odd behaviour and post the code here.

Hi Simon!

First, unfortunately, I’m not able to load your sample data with
my creaking, wheezing, 0.3.0-version of pytorch. If they’re not
too large, could you post them as text files?

You say, “The x_recon values are all very small (around -1.0e+05).”
You give a numerical value of -100,000. I would call this “a rather
large negative number.” (“Small,” to me, has the connotation of
“close to zero.”)

Anyway, let’s go with -100,000. That is, your logits (inputs)
are rather large negative numbers. So your probabilities
(sigmoid (logit)) are all (positive) numbers quite close to zero.
But (with 32-bit floating-point) they underflow and become
exactly zero. (32-bit sigmoid underflows to zero somewhere
around sigmoid (-90.0).)
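A quick way to see where float32 sigmoid() gives up, and that working in log space with F.logsigmoid avoids the problem, is the sketch below. The exact cutoff depends on how sigmoid is evaluated, so it only brackets the threshold with -80 and -120 rather than pinning it down:

```python
import torch
import torch.nn.functional as F

for v in (-80.0, -120.0):
    x32 = torch.tensor(v, dtype=torch.float32)
    x64 = torch.tensor(v, dtype=torch.float64)
    # float32 sigmoid: tiny but positive at -80, exactly 0.0 at -120
    # float64 sigmoid and float32 logsigmoid stay finite and nonzero at both
    print(
        f"x={v:6.1f}  sigmoid32={torch.sigmoid(x32).item():.3e}  "
        f"sigmoid64={torch.sigmoid(x64).item():.3e}  "
        f"logsigmoid32={F.logsigmoid(x32).item():.1f}"
    )
```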

Given this, I’m surprised you’re not getting NaNs (from the
log (0.0) inside of binary_cross_entropy()).
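The BCELoss documentation for current PyTorch versions states that the log outputs are clamped to be greater than or equal to -100, which is the likely reason no NaNs appear. (Whether PyTorch 0.3.0 already did this is not something I can confirm.) A small sketch comparing the textbook formula with the built-in loss on underflowed probabilities:

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.0, 0.0])  # sigmoid outputs that have underflowed to exactly 0
y = torch.tensor([1.0, 0.0])

# Textbook formula: log(0) = -inf, and 0 * -inf = nan
naive = -(y * torch.log(p) + (1 - y) * torch.log(1 - p))
print(naive)    # tensor([inf, nan])

# Built-in loss: log outputs clamped at -100, so the result stays finite
builtin = F.binary_cross_entropy(p, y, reduction="none")
print(builtin)  # finite: 100 for the first element, 0 for the second
```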

Anyway, could you tell us the shape of x_recon and x_view,
as well as the (algebraic) minima and maxima of x_recon
and x_view?

Assuming that your x_recon really have become that large and
negative, it’s not surprising that you’re getting weird results
(at least for plain bce) – you’ve long since passed into the
range where sigmoid() underflows. (This still doesn’t explain
why you’re getting seemingly good results with plain bce, but
things break down with bce_with_logits.)

Best.

K. Frank

Ah, sorry about that; you are of course right: the x_recon values are large negative numbers. Only their sigmoids are very small.

Both have dimensions torch.Size([100, 784]). The first dimension is the batch size; the second is the number of pixels in an MNIST training image. The extrema are:

torch.min(x_recon) # -16971.4434
torch.max(x_recon) # 15807.5469
torch.min(x_view) # -0.4242
torch.max(x_view) # 2.8215

So it turns out there are also very big entries in x_recon that saturate the sigmoid towards 1.
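A rough way to quantify how much of x_recon sits in the saturated range is to count the entries whose float32 sigmoid is exactly 0 or exactly 1. Note that rounding to 1.0 happens at much smaller logits (somewhere around 17 in float32) than rounding to 0.0. The helper saturation_report is hypothetical, and the demo reuses the extrema quoted above instead of the uploaded file:

```python
import torch

def saturation_report(logits):
    # Fractions of entries whose sigmoid is exactly 0.0 or exactly 1.0.
    p = torch.sigmoid(logits)
    return {
        "shape": tuple(logits.shape),
        "min": logits.min().item(),
        "max": logits.max().item(),
        "frac_zero": (p == 0).float().mean().item(),
        "frac_one": (p == 1).float().mean().item(),
    }

# Hypothetical usage on the real data: print(saturation_report(torch.load("x_recon.pt")))
print(saturation_report(torch.tensor([-16971.4434, 0.0, 15807.5469])))
```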

I created a demo with minimal changes to the original example that reproduces the odd behaviour here. The relevant change was the normalization of the data to zero mean and unit variance (i.e. transform = transforms.Normalize((0.1307,), (0.3081,))).
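transforms.Normalize computes (x - mean) / std per channel, and ToTensor() puts MNIST pixels in [0.0, 1.0]. The two constants alone therefore predict the x_view extrema reported above. A plain-torch sketch of that arithmetic (torchvision not required):

```python
import torch

mean, std = 0.1307, 0.3081          # the constants used in the demo's Normalize
pixels = torch.tensor([0.0, 1.0])   # black and white pixels after ToTensor()
normalized = (pixels - mean) / std  # the per-channel operation Normalize applies
print(normalized)                   # approximately tensor([-0.4242,  2.8215])
```

Both values fall outside [0, 1], so after normalization every black pixel becomes a negative BCE target and every white pixel becomes a target above 1.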

Hi Simon!

Okay, this makes more sense.

The input and target passed to binary_cross_entropy()
are both supposed to be probabilities, that is, numbers between
0 and 1 (with the log() singularities occurring at input values of 0 and 1).

Your target contains negative numbers, which are not valid
probabilities.

(Because you pass your x_recon through sigmoid(), your input
will always contain valid probabilities, although they will sometimes
saturate at the singular 0 and 1.)

So my guess is that your bogus target values are causing your
training to drive your inputs to the large values that saturate
sigmoid(). (Why you’re not getting NaNs, I don’t know.)

Given your bogus target, I’m not surprised that you’re getting
weird results. (I can come up with plausible speculations about
why plain bce and bce_with_logits differ, but that’s not really the
point.)
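One plausible mechanism for the difference, offered as a sketch rather than a diagnosis: for very negative logits x, the logits loss reduces to roughly -x·y. With a negative target y this is linear in x and unbounded below, so the optimizer can keep lowering the loss by pushing logits further down. (A target above 1, such as 2.8215, does the same in the positive direction.) The plain version instead saturates: sigmoid() hits 0, the log is clamped at -100 (per the BCELoss documentation), and the gradient through a saturated sigmoid is zero. The plain loss is written out by hand below, following that documented clamping, so the sketch does not depend on how a given PyTorch version treats out-of-range targets:

```python
import torch
import torch.nn.functional as F

def plain_bce_clamped(p, y):
    # Hand-written BCE that clamps log outputs at -100, as the BCELoss docs describe.
    return -(y * torch.log(p).clamp(min=-100) + (1 - y) * torch.log(1 - p).clamp(min=-100))

y = torch.tensor(-0.4242)  # the smallest normalized-MNIST target from this thread
for logit in (-1e2, -1e3, -1e4):
    x = torch.tensor(logit)
    with_logits = F.binary_cross_entropy_with_logits(x, y)  # keeps falling, about -0.4242 * |x|
    plain = plain_bce_clamped(torch.sigmoid(x), y)           # stuck near -42.42
    print(f"{logit:>8.0f}  with_logits={with_logits.item():10.2f}  plain={plain.item():8.2f}")
```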

Figure out how to fix your inputs to binary_cross_entropy(),
specifically x_view, and see if that cleans things up, or at
least improves the situation.
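Two hedged ways to act on this: check the target range before computing the loss, and, if normalized inputs are still wanted for the encoder, compute the reconstruction loss against un-normalized targets. Both helpers below (check_bce_target, denormalize) are hypothetical, and undoing the normalization is an alternative to the simpler route of dropping Normalize altogether:

```python
import torch

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # constants from the demo's transforms.Normalize

def check_bce_target(target):
    # BCE targets must be probabilities in [0, 1].
    lo, hi = target.min().item(), target.max().item()
    if lo < 0.0 or hi > 1.0:
        raise ValueError(f"BCE target outside [0, 1]: min={lo:.4f}, max={hi:.4f}")
    return target

def denormalize(x, mean=MNIST_MEAN, std=MNIST_STD):
    # Invert (x - mean) / std, then clip float rounding error back into [0, 1].
    return (x * std + mean).clamp(0.0, 1.0)

x_norm = (torch.tensor([0.0, 0.5, 1.0]) - MNIST_MEAN) / MNIST_STD
try:
    check_bce_target(x_norm)
except ValueError as err:
    print(err)  # the normalized batch is rejected
print(check_bce_target(denormalize(x_norm)))  # back in [0, 1]
```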

(I haven’t tried to run your demo, because, among other reasons,
it likely won’t run with my decrepit pytorch 0.3.0.)

Good luck.

K. Frank


Okay, so normalization to zero mean and unit variance should be removed in this case, as it violates the probabilistic interpretation of the images. That solves my issue, thank you!