Implementation of Binary Cross Entropy?

Hello,

  1. Yes, the two are mathematically equivalent, but BCELoss applied to sigmoid outputs is less numerically stable than BCEWithLogitsLoss (see the sketch after this list).
  2. The code of the BCEWithLogitsLoss class can be found at https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
    There you will find its forward method:
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.binary_cross_entropy_with_logits(input, target,
                                                  self.weight,
                                                  pos_weight=self.pos_weight,
                                                  reduction=self.reduction)

The F module is imported from functional.py here: https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py

There you will find the function:

def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
                                     reduce=None, reduction='mean', pos_weight=None):
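For reference, a minimal call (my own example, not from the docs) showing the extra arguments this functional form exposes; weight and pos_weight both broadcast against the input:

import torch
import torch.nn.functional as F

logits = torch.randn(3, 2)
target = torch.rand(3, 2)          # soft targets in [0, 1] are allowed

# weight rescales each element; pos_weight rescales the positive term,
# which is handy for class imbalance.
loss = F.binary_cross_entropy_with_logits(
    logits, target,
    weight=torch.tensor([1.0, 2.0]),
    pos_weight=torch.tensor([3.0, 1.0]),
    reduction='mean',
)
print(loss)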

Before doing the actual computation, it calls handle_torch_function in https://github.com/pytorch/pytorch/blob/master/torch/overrides.py
You will find an entry for binary_cross_entropy_with_logits in the ret dictionary, which contains every function that can be overridden in PyTorch.
This is the Python side of the __torch_function__ override mechanism.
More info in https://github.com/pytorch/pytorch/issues/24015
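As an illustration of that mechanism (my own sketch, not PyTorch source), a Tensor subclass can intercept the call through __torch_function__:

import torch
import torch.nn.functional as F

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Every overridable torch function is routed through here first.
        print("intercepted:", func.__name__)
        return super().__torch_function__(func, types, args, kwargs or {})

logits = torch.randn(4).as_subclass(LoggingTensor)
target = torch.rand(4).as_subclass(LoggingTensor)
F.binary_cross_entropy_with_logits(logits, target)
# prints "intercepted: binary_cross_entropy_with_logits" (among other calls)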

Then the actual computation is performed in the C++ file
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Loss.cpp


Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) {
    Tensor loss;
    auto max_val = (-input).clamp_min_(0);
    if (pos_weight.defined()) {
        // pos_weight needs to be broadcast, thus mul(target) is not inplace.
        auto log_weight = (pos_weight - 1).mul(target).add_(1);
        loss = (1 - target).mul_(input).add_(
            log_weight.mul_(((-max_val).exp_().add_((-input - max_val).exp_())).log_().add_(max_val)));
    } else {
        loss = (1 - target).mul_(input).add_(max_val).add_(
            (-max_val).exp_().add_((-input - max_val).exp_()).log_());
    }

    if (weight.defined()) {
        loss.mul_(weight);
    }

    return apply_loss_reduction(loss, reduction);
}

It takes advantage of the log-sum-exp trick for numerical stability:
https://en.wikipedia.org/wiki/LogSumExp
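To make that concrete, here is a line-by-line Python transcription (mine, for illustration) of the branch without pos_weight, checked against the built-in:

import torch
import torch.nn.functional as F

def bce_with_logits_stable(x, y):
    # max_val = max(-x, 0) keeps every exp() argument <= 0, so nothing
    # overflows; this is exactly the log-sum-exp trick from Loss.cpp.
    max_val = (-x).clamp_min(0)
    return (1 - y) * x + max_val + ((-max_val).exp() + (-x - max_val).exp()).log()

x = torch.randn(5) * 50            # include large-magnitude logits
y = torch.randint(0, 2, (5,)).float()
print(torch.allclose(bce_with_logits_stable(x, y).mean(),
                     F.binary_cross_entropy_with_logits(x, y)))   # True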

You can compare it with BCELoss, implemented in the binary_cross_entropy_out_cpu function.
It is less stable because it applies L = -w (y ln(x) + (1-y) ln(1-x)) directly on the probabilities:


                // Binary cross entropy tensor is defined by the equation:
                // L = -w (y ln(x) + (1-y) ln(1-x))
                return (target_val - scalar_t(1))
                    * std::max(scalar_t(std::log(scalar_t(1) - input_val)), scalar_t(-100))
                    - target_val * std::max(scalar_t(std::log(input_val)), scalar_t(-100));
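A quick way (again my own sketch) to watch that clamp at -100 kick in:

import torch
import torch.nn.functional as F

x = torch.sigmoid(torch.tensor([40.0]))   # rounds to exactly 1.0 in float32
y = torch.tensor([0.0])

# log(1 - 1.0) = log(0) = -inf, so the log term is clamped at -100 and
# the reported loss is 100 instead of the true value of 40.
naive = -(y * torch.log(x).clamp_min(-100)
          + (1 - y) * torch.log(1 - x).clamp_min(-100))
print(naive.item())                                  # 100.0
print(F.binary_cross_entropy(x, y).item())           # 100.0, same clamp
print(F.binary_cross_entropy_with_logits(
    torch.tensor([40.0]), y).item())                 # 40.0, the exact loss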