While tinkering with the official code example for Variational Autoencoders, I experienced some unexpected behaviour with regard to the Binary Cross-Entropy loss. When I use F.binary_cross_entropy in combination with the sigmoid function, the model trains as expected on MNIST. However, when changing to the F.binary_cross_entropy_with_logits function, the loss suddenly becomes arbitrarily small during training and the model no longer produces meaningful results.

# For this loss function the loss becomes arbitrarily small
BCE = F.binary_cross_entropy_with_logits(recon_x, x.view(-1, 784), reduction='sum') / x.shape[0]
# For this loss function, the training works as expected
BCE = F.binary_cross_entropy(torch.sigmoid(recon_x), x.view(-1, 784), reduction='sum') / x.shape[0]

To my understanding, the only difference between the two approaches should be numerical stability. Am I missing something?

I agree with your understanding, and looking at the two lines of
code you posted, I don’t see anything suspicious (although I miss
things all the time …).

If I had to guess, I would guess that you have a typo somewhere
else in your code that causes the two runs to differ.

However, we can start by checking out some basics:

Here I run the example given in the documentation for torch.nn.functional.binary_cross_entropy. (Please note
that I am running this test, for whatever reason, with pytorch 0.3.0.)

And, indeed, the two expressions are the same (up to floating-point
precision).
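For reference, a comparison along those lines might look like this (a sketch with made-up tensors, not your actual data; on well-behaved inputs the two should agree to floating-point precision):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 5)   # hypothetical raw model outputs
target = torch.rand(3, 5)    # valid probabilities in [0, 1]

bce_logits = F.binary_cross_entropy_with_logits(logits, target, reduction='sum')
bce_plain = F.binary_cross_entropy(torch.sigmoid(logits), target, reduction='sum')

# For moderate logits the two agree up to floating-point precision
print(torch.allclose(bce_logits, bce_plain, atol=1e-4))  # True
```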

Is there any way you can capture a specific instance of your recon_x and x.view() so that you can pump what you know
are the same values into your two cross-entropy expressions?

Is there any way you can track your loss (your BCE) on a step-by-step
basis so you can see when the two versions first diverge?

The thing that is noteworthy to me is that you say that the less
numerically stable version (regular bce) works, while the more
stable version (bce_with_logits) quits working at some point.
This is backwards of what one might expect.

Perhaps there is something about your model that puts it on
the edge of being poorly behaved. In such a case it could be
plausible that, by happenstance, the bce version stays in a
well-behaved region, but small differences due to using the
bce_with_logits version cause it to drift into a poorly-behaved
region.

If this were the case I would expect that other perturbations
such as starting with different initial weights, or using a different
optimization algorithm or learning rate could cause the training
to “randomly” end up being well behaved or poorly behaved.

So I would:

1. proofread the code to make sure there isn’t some outright error;

2. check that your model and data are reasonably well behaved
and stable with respect to perturbations of the details of your
training;

3. track your bce and bce_logits step by step to find out where
they first diverge, and drill down with the values that immediately
precede the divergence.
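The step-by-step tracking in point 3 could be sketched like this (a hypothetical helper; the divergence threshold is arbitrary):

```python
import torch
import torch.nn.functional as F

def both_bce(recon_x, x_view):
    # Hypothetical helper: evaluate both loss variants on identical tensors
    with_logits = F.binary_cross_entropy_with_logits(recon_x, x_view, reduction='sum')
    plain = F.binary_cross_entropy(torch.sigmoid(recon_x), x_view, reduction='sum')
    return with_logits.item(), plain.item()

# In the training loop, compare every step and dump the first tensors
# that cause a divergence, e.g.:
#   wl, pl = both_bce(recon_x, x.view(-1, 784))
#   if abs(wl - pl) > 1e-3 * max(abs(wl), 1.0):
#       torch.save(recon_x, 'recon_x.pt')
#       torch.save(x, 'x.pt')
```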

Thank you for your elaborate response. I have uploaded an instance of the tensors for which the two losses diverge here. This is the resulting output for them:

import torch
x_view = torch.load("x_view.pt")
x_recon = torch.load("x_recon.pt")
BCE1 = torch.nn.functional.binary_cross_entropy_with_logits(x_recon, x_view, reduction='sum')
BCE2 = torch.nn.functional.binary_cross_entropy(torch.sigmoid(x_recon), x_view, reduction='sum')
print("BCE loss with logits: ", BCE1) # -2.3662e+08
print("BCE loss with sigmoid: ", BCE2) # -379848.2500
print("Loss difference: ", BCE1-BCE2) # -2.3624e+08

The x_recon values are all very small (around -1.0e+05). However, I am not sure why the version without logits behaves so differently in this regime. Do you have any insights in this regard?

In the meantime I’ll try to figure out the minimal changes to the official example that are necessary to reproduce the odd behaviour and post the code here.

First, unfortunately, I’m not able to load your sample data with
my creaking, wheezing, 0.3.0-version of pytorch. If they’re not
too large, could you post them as text files?

You say, “The x_recon values are all very small (around -1.0e+05).”
That is a numerical value of -100,000. I would call this “a rather
large negative number.” (“Small,” to me, has the connotation of
“close to zero.”)

Anyway, let’s go with -100,000. That is, your logits (inputs)
are rather large negative numbers. So your probabilities
(sigmoid(logit)) are all (positive) numbers quite close to zero.
But (with 32-bit floating-point) they underflow and become
exactly zero. (32-bit sigmoid underflows to zero somewhere
around sigmoid(-90.0).)

Given this, I’m surprised you’re not getting NaNs (from the log(0.0) inside of binary_cross_entropy()).
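The underflow itself is easy to check directly (float32 throughout):

```python
import torch

# sigmoid of a large negative float32 input underflows to exactly 0.0,
# and the log of that would normally poison the loss
p = torch.sigmoid(torch.tensor(-100000.0))
print(p.item())             # 0.0
print(torch.log(p).item())  # -inf
```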

Anyway, could you tell us the shape of x_recon and x_view,
as well as the (algebraic) minima and maxima of x_recon
and x_view?

Assuming that your x_recon really have become something
like -100,000, it’s not surprising that you’re getting weird results
(at least for plain bce) – you’ve long since passed into the
range where sigmoid() underflows. (This still doesn’t explain
why you’re getting seemingly good results with plain bce, but
things break down with bce_with_logits.)

Ah sorry about that, you are of course right: the x_recon values are large negative numbers; it is only their sigmoid that is very small.

Both have dimensions torch.Size([100, 784]). The first dimension is the batch size, the second is the number of pixels in an MNIST training image. Looking at the extrema, it turns out there are also very big entries in x_recon that saturate the sigmoid towards 1.

I created a demo with minimal changes to the original example that reproduces the odd behaviour here. The relevant change was the normalization of the data to zero mean and unit variance (i.e. transform = transforms.Normalize((0.1307,), (0.3081,))).

The input and target passed to binary_cross_entropy()
are both supposed to be probabilities, that is, numbers between
0 and 1 (with singularities occurring at 0 and 1).

Your target contains negative numbers, which are not valid
probabilities.
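With the Normalize constants from your demo this is easy to see; a black pixel (0.0 after ToTensor) becomes:

```python
import torch

# Normalize((0.1307,), (0.3081,)) maps a pixel value of 0.0 to
# (0 - 0.1307) / 0.3081, a negative number that then serves as a BCE target
pixel = torch.tensor(0.0)
normalized = (pixel - 0.1307) / 0.3081
print(normalized.item())  # about -0.4242 -- not a valid probability
```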

(Because you pass your x_recon through sigmoid(), your input
will always contain valid probabilities, although they will sometimes
saturate at the singular 0 and 1.)

So my guess is that your bogus target values are causing your
training to drive your inputs to the large values that saturate sigmoid(). (Why you’re not getting NaNs, I don’t know.)

Given your bogus target, I’m not surprised that you’re getting
weird results. (I can come up with plausible speculations about
why plain bce and bce_with_logits differ, but that’s not really the
point.)
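For what it’s worth, the sign of the divergence is at least consistent with how bce_with_logits handles an invalid negative target: it does not validate its target, and its stable per-element formula, max(z, 0) - z*t + log(1 + exp(-|z|)), becomes arbitrarily negative as the logit z goes to minus infinity when t is negative. A minimal sketch:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([-50.0])
bad_target = torch.tensor([-1.0])  # invalid "probability", as produced by the normalization

# max(z, 0) - z*t + log(1 + exp(-|z|)) with z = -50, t = -1
# gives 0 - 50 + ~0, i.e. roughly -50; more negative logits drive it lower
loss = F.binary_cross_entropy_with_logits(logits, bad_target, reduction='sum')
print(loss.item())  # about -50.0
```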

Figure out how to fix your inputs to binary_cross_entropy(),
specifically x_view, and see if that cleans things up, or at
least improves the situation.

(I haven’t tried to run your demo, because, among other reasons,
it likely won’t run with my decrepit pytorch 0.3.0.)

Okay, so normalization to zero mean, unit variance should be removed in this case as it violates the probabilistic interpretation of the images. That solves my issue, thank you!
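For completeness: dropping the Normalize step and keeping only ToTensor leaves the pixels in [0, 1], i.e. valid BCE targets. In plain tensor terms (a sketch with random data in place of MNIST):

```python
import torch

# Raw MNIST pixels are uint8 in [0, 255]; dividing by 255 (which is what
# transforms.ToTensor() does to the values) yields valid probabilities
raw = torch.randint(0, 256, (100, 784), dtype=torch.uint8)
x_view = raw.float() / 255.0
print(x_view.min().item() >= 0.0 and x_view.max().item() <= 1.0)  # True
```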