Sigmoid vs Binary Cross Entropy Loss

In my torch model, the last layer is a torch.nn.Sigmoid() and the loss is the torch.nn.BCELoss.
During the training step, the following error occurred:

RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
or torch.nn.BCEWithLogitsLoss.  binary_cross_entropy_with_logits and BCEWithLogits are
safe to autocast.
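A minimal sketch of the swap the error message suggests, assuming a hypothetical toy model (a single nn.Linear) in place of the real one: drop the final nn.Sigmoid() and pass the raw logits to nn.BCEWithLogitsLoss. On ordinary float32 inputs both losses should agree closely; the fused version just computes the same quantity more stably.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model; the final nn.Sigmoid() is intentionally left out
model = nn.Linear(4, 1)
x = torch.randn(8, 4)
target = torch.randint(0, 2, (8, 1)).float()  # BCE targets must be floats

logits = model(x)

# Pattern from the question: Sigmoid followed by BCELoss
loss_sigmoid_bce = nn.BCELoss()(torch.sigmoid(logits), target)

# Pattern suggested by the error message: BCEWithLogitsLoss on raw logits
loss_with_logits = nn.BCEWithLogitsLoss()(logits, target)

# Probabilities are still available when needed, e.g. for metrics or inference
probs = torch.sigmoid(logits)

print(loss_sigmoid_bce.item(), loss_with_logits.item())
```

The only change to the model is removing the last layer; anything downstream that expects probabilities should apply torch.sigmoid to the logits itself.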

However, when I try to reproduce this error by computing the loss and backpropagating, everything works:

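import torch
from torch import nn

# last layer
sigmoid = nn.Sigmoid()
# loss
bce_loss = nn.BCELoss()

# the true classes (illustrative values; the original tensor was cut off)
true_cls = torch.tensor([[0.], [1.]])

# model prediction classes (illustrative logits; the original input was cut off)
logits = torch.tensor([[0.4951], [0.4823]], requires_grad=True)
pred_cls = sigmoid(logits)
# printed in the original post (exact values depend on the inputs):
# tensor([[0.6213],
#         [0.6183]], grad_fn=<SigmoidBackward>)

out = bce_loss(pred_cls, true_cls)
# tensor(0.7258, grad_fn=<BinaryCrossEntropyBackward>)

# backpropagation also runs without error here (no autocast is involved)
out.backward()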


What am I missing?
I appreciate any help you can provide.


The posted code snippet doesn't use mixed-precision training, and in particular doesn't use torch.cuda.amp.autocast, which is what raises the initial error.
The reason for this error is the poor numerical stability of sigmoid + nn.BCELoss: it is already less stable than nn.BCEWithLogitsLoss in float32 and could suffer even more in mixed-precision training.
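To make the stability point concrete, here is a small sketch with one confidently wrong prediction (logit 20, target 0). In float32, sigmoid(20) rounds to exactly 1.0, so BCELoss has to take log(0), which it clamps to -100 and reports a loss of 100. The fused loss works on the logit directly and returns roughly the true value of 20. The hand-written line is the usual log-sum-exp rearrangement, max(x, 0) - x*y + log(1 + exp(-|x|)), included only for illustration; it is not necessarily how PyTorch's kernel is written.

```python
import torch
from torch import nn

# A confidently wrong prediction: large positive logit, target 0
logit = torch.tensor([20.0])
target = torch.tensor([0.0])

# Sigmoid first: in float32, sigmoid(20) rounds to exactly 1.0,
# so log(1 - p) = log(0); BCELoss clamps that log at -100
p = torch.sigmoid(logit)
naive = nn.BCELoss()(p, target)

# Fused version works on the logit directly and never forms log(0)
fused = nn.BCEWithLogitsLoss()(logit, target)

# Hand-written stable form: max(x, 0) - x*y + log(1 + exp(-|x|))
manual = (logit.clamp(min=0) - logit * target
          + torch.log1p(torch.exp(-logit.abs()))).mean()

print(p.item(), naive.item(), fused.item(), manual.item())
```

In float16, sigmoid saturates at much smaller logits than in float32, which is why the unfused version is treated as unsafe under autocast.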


The posted code doesn't raise an error in PyTorch anymore, but it still raises the same error when I use PyTorch Lightning with mixed precision. So indeed we should always use nn.BCEWithLogitsLoss rather than sigmoid + nn.BCELoss.

I don’t think the posted code was using amp, but this code would still raise the error:

import torch
import torch.nn as nn

x = torch.sigmoid(torch.randn(10, 1, requires_grad=True, device='cuda'))
y = torch.randint(0, 2, (10, 1), device='cuda').float()
criterion = nn.BCELoss()

with torch.cuda.amp.autocast():
    loss = criterion(x, y)


In any case, you are right that nn.BCEWithLogitsLoss should be used for numerical stability reasons.
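For reference, a sketch of the snippet above rewritten to compute the loss from logits inside an autocast region. It assumes a hypothetical nn.Linear model without a final Sigmoid, and falls back to CPU autocast with bfloat16 when no GPU is present so it can still run; that fallback and the explicit `.float()` cast are my own additions for portability, not something required on CUDA. For real float16 training you would normally also use a gradient scaler (torch.cuda.amp.GradScaler), which is left out to keep this short.

```python
import torch
from torch import nn

# Use CUDA + float16 when available; otherwise fall back to CPU + bfloat16
# so the sketch still runs (this fallback is an assumption for portability)
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

model = nn.Linear(10, 1).to(device)  # hypothetical model without a final Sigmoid
x = torch.randn(16, 10, device=device)
y = torch.randint(0, 2, (16, 1), device=device).float()
criterion = nn.BCEWithLogitsLoss()

with torch.autocast(device_type=device, dtype=amp_dtype):
    logits = model(x)  # low-precision output under autocast
    # On CUDA, autocast already runs BCEWithLogitsLoss in float32;
    # the explicit .float() just keeps the CPU fallback's dtypes consistent
    loss = criterion(logits.float(), y)

loss.backward()
print(loss.dtype, loss.item())
```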
