Implementation of Binary Cross Entropy?

Hi All,

I want to write code for label smoothing using BCEWithLogitsLoss.
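
Something along these lines is what I have in mind (just a rough sketch; eps is a placeholder smoothing factor I picked, not anything official):

import torch
import torch.nn as nn

eps = 0.1                                   # placeholder smoothing factor
criterion = nn.BCEWithLogitsLoss()
logits = torch.randn(8, 1)                  # raw model outputs (logits)
targets = torch.empty(8, 1).random_(2)      # hard 0/1 labels
smoothed = targets * (1 - eps) + 0.5 * eps  # pull labels towards 0.5
loss = criterion(logits, smoothed)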

Q1) Is BCEWithLogitsLoss = BCELoss + sigmoid()?
Q2) While checking the PyTorch GitHub source I found the following code, in which there is no sigmoid implementation. Am I looking at the wrong files?

Can someone tell me where the actual BCEWithLogitsLoss code is written?

class BCEWithLogitsLoss(_Loss):

    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean',
                 pos_weight: Optional[Tensor] = None) -> None:
        super(BCEWithLogitsLoss, self).__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.binary_cross_entropy_with_logits(input, target,
                                                  self.weight,
                                                  pos_weight=self.pos_weight)



def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
                                     reduce=None, reduction='mean', pos_weight=None):
    # type: (Tensor, Tensor, Optional[Tensor], Optional[bool], Optional[bool], str, Optional[Tensor]) -> Tensor
    r"""Function that measures Binary Cross Entropy between target and output
    logits.

    See :class:`~torch.nn.BCEWithLogitsLoss` for details.

    Args:
        input: Tensor of arbitrary shape
        target: Tensor of the same shape as input
        weight (Tensor, optional): a manual rescaling weight
            if provided it's repeated to match input tensor shape
        size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
            the losses are averaged over each loss element in the batch. Note that for
            some losses, there are multiple elements per sample. If the field :attr:`size_average`
            is set to ``False``, the losses are instead summed for each minibatch. Ignored
            when reduce is ``False``. Default: ``True``
        reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
            losses are averaged or summed over observations for each minibatch depending
            on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
            batch element instead and ignores :attr:`size_average`. Default: ``True``
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
            specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
        pos_weight (Tensor, optional): a weight of positive examples.
                Must be a vector with length equal to the number of classes.

    Examples::

         >>> input = torch.randn(3, requires_grad=True)
         >>> target = torch.empty(3).random_(2)
         >>> loss = F.binary_cross_entropy_with_logits(input, target)
         >>> loss.backward()
    """
    if not torch.jit.is_scripting():
        tens_ops = (input, target)
        if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
            return handle_torch_function(
                binary_cross_entropy_with_logits, tens_ops, input, target, weight=weight,
                size_average=size_average, reduce=reduce, reduction=reduction,
                pos_weight=pos_weight)
    if size_average is not None or reduce is not None:
        reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
    else:
        reduction_enum = _Reduction.get_enum(reduction)

    if not (target.size() == input.size()):
        raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

Thanks !!!


Hello,

  1. Yes, they are mathematically equivalent, but applying sigmoid() yourself and then BCELoss is numerically less stable (see the quick numerical check after this list).
  2. The code of the BCEWithLogitsLoss class can be found in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
    There you will find this forward method, which calls the functional version:
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.binary_cross_entropy_with_logits(input, target,
                                                  self.weight,
                                                  pos_weight=self.pos_weight,
                                                  reduction=self.reduction)
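
A quick numerical check of point 1 (just a sketch; exact values depend on the random inputs):

import torch
import torch.nn.functional as F

logits = torch.randn(4)
targets = torch.empty(4).random_(2)

# fused, numerically stable version
with_logits = F.binary_cross_entropy_with_logits(logits, targets)
# explicit sigmoid followed by plain BCE
two_step = F.binary_cross_entropy(torch.sigmoid(logits), targets)

print(torch.allclose(with_logits, two_step))  # True for moderate logits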

The F object in the forward method above is torch.nn.functional, defined in functional.py here: https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py

There you will find the definition of the function:

def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
                                     reduce=None, reduction='mean', pos_weight=None):

It calls handle_torch_function, defined in https://github.com/pytorch/pytorch/blob/master/torch/overrides.py
You will find an entry for binary_cross_entropy_with_logits in the ret dictionary there, which contains every function that can be overridden in PyTorch.
This is the Python side of the __torch_function__ override mechanism.
More info in https://github.com/pytorch/pytorch/issues/24015
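
For illustration only (a minimal sketch of that override mechanism, not something BCEWithLogitsLoss itself does), a Tensor subclass can intercept the call via __torch_function__:

import torch
import torch.nn.functional as F

class LoggingTensor(torch.Tensor):
    # Calls involving this subclass are routed through handle_torch_function,
    # which invokes this hook before falling back to the normal implementation.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print("intercepted:", getattr(func, "__name__", func))
        return super().__torch_function__(func, types, args, kwargs)

logits = torch.randn(3).as_subclass(LoggingTensor)
targets = torch.empty(3).random_(2).as_subclass(LoggingTensor)
loss = F.binary_cross_entropy_with_logits(logits, targets)
# prints "intercepted: binary_cross_entropy_with_logits" (possibly among other calls)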

Then the code that actually runs is in this C++ file:
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Loss.cpp


Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) {
    Tensor loss;
    auto max_val = (-input).clamp_min_(0);
    if (pos_weight.defined()) {
        // pos_weight need to be broadcasted, thus mul(target) is not inplace.
        auto log_weight = (pos_weight - 1).mul(target).add_(1);
        loss = (1 - target).mul_(input).add_(log_weight.mul_(((-max_val).exp_().add_((-input - max_val).exp_())).log_().add_(max_val)));
    } else {
        loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
    }

    if (weight.defined()) {
        loss.mul_(weight);
    }

    return apply_loss_reduction(loss, reduction);
}

It takes advantage of the log-sum-exp trick for numerical stability:
https://en.wikipedia.org/wiki/LogSumExp
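
To make the structure of that line easier to read, here is the same formula transcribed into Python (a sketch per element, without weight / pos_weight; it should match F.binary_cross_entropy_with_logits up to floating point error):

import torch
import torch.nn.functional as F

def stable_bce_with_logits(x, y):
    # Same formula as the C++ else-branch above:
    # (1 - y) * x + log(1 + exp(-x)), where the log term is written with the
    # log-sum-exp trick so that exp() never overflows.
    max_val = (-x).clamp(min=0)
    return (1 - y) * x + max_val + ((-max_val).exp() + (-x - max_val).exp()).log()

x = torch.randn(5) * 20                     # include some large-magnitude logits
y = torch.empty(5).random_(2)
print(torch.allclose(stable_bce_with_logits(x, y),
                     F.binary_cross_entropy_with_logits(x, y, reduction='none')))
# expected: True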

You can compare it with BCELoss in the binary_cross_entropy_out_cpu() function.
It is less stable because it applies L = -w (y ln(x) + (1-y) ln(1-x)) directly to the already-squashed probabilities x:


                // Binary cross entropy tensor is defined by the equation:
                // L = -w (y ln(x) + (1-y) ln(1-x))
                return (target_val - scalar_t(1))
                    * std::max(scalar_t(std::log(scalar_t(1) - input_val)), scalar_t(-100))
                    - target_val * std::max(scalar_t(std::log(input_val)), scalar_t(-100));
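
A small example of where the two diverge (assuming float32 defaults; the printed values depend on the clamp shown above):

import torch
import torch.nn.functional as F

x = torch.tensor([-200.0])   # very confident, very wrong logit
y = torch.tensor([1.0])

# sigmoid(-200) underflows to exactly 0.0 in float32, so ln(x) would be -inf;
# the clamp at -100 in the kernel above caps the loss at 100.
print(F.binary_cross_entropy(torch.sigmoid(x), y))         # -> tensor(100.)

# The with-logits version never materialises sigmoid(x) and returns the true value.
print(F.binary_cross_entropy_with_logits(x, y))            # -> tensor(200.)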

Hello Surya and Pytorchtester!

To clarify a bit:

Mathematically, BCEWithLogitsLoss is sigmoid() followed by
BCELoss. But numerically they are different, with the BCELoss
version being the less stable of the two.

Elaborating on the above, sigmoid() is not there, because it is
not explicitly part of BCEWithLogitsLoss. It is hiding in the
log(sigmoid()) version of the "log-sum-exp trick," in this line
from the C++ code that Pytorchtester posted:

loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
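
Spelling that out (same plain notation as above, with
max_val = max(-x, 0) as in the C++ code):

-ln(sigmoid(x))     = ln(1 + exp(-x))
-ln(1 - sigmoid(x)) = x + ln(1 + exp(-x))

so

loss = -[ y ln(sigmoid(x)) + (1 - y) ln(1 - sigmoid(x)) ]
     = (1 - y) x + ln(1 + exp(-x))
     = (1 - y) x + max_val + ln(exp(-max_val) + exp(-x - max_val))

which is exactly the line above: the sigmoid only ever appears
inside a log, so it never has to be evaluated on its own.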

Best.

K. Frank
