Hello,
- Yes, the two are mathematically equivalent, but BCELoss is less numerically stable.
- The code of the BCEWithLogitsLoss class can be found in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
There, the forward method calls F.binary_cross_entropy_with_logits:
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.binary_cross_entropy_with_logits(input, target,
                                                  self.weight,
                                                  pos_weight=self.pos_weight,
                                                  reduction=self.reduction)
The F object is the functional module, imported from functional.py here: https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
There you will find the function being called:
    def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None,
                                         reduce=None, reduction='mean', pos_weight=None):
It calls handle_torch_function, defined in https://github.com/pytorch/pytorch/blob/master/torch/overrides.py
There you will find an entry for binary_cross_entropy_with_logits in the ret dictionary, which contains every function that can be overridden in PyTorch.
This is the Python implementation of __torch_function__.
More info in https://github.com/pytorch/pytorch/issues/24015
The code that is then called is in the C++ file
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Loss.cpp
    Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) {
        Tensor loss;
        auto max_val = (-input).clamp_min_(0);
        if (pos_weight.defined()) {
            // pos_weight need to be broadcasted, thus mul(target) is not inplace.
            auto log_weight = (pos_weight - 1).mul(target).add_(1);
            loss = (1 - target).mul_(input).add_(log_weight.mul_(((-max_val).exp_().add_((-input - max_val).exp_())).log_().add_(max_val)));
        } else {
            loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
        }
        if (weight.defined()) {
            loss.mul_(weight);
        }
        return apply_loss_reduction(loss, reduction);
    }
This implementation takes advantage of the log-sum-exp trick for numerical stability:
https://en.wikipedia.org/wiki/LogSumExp
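To make the trick concrete, here is a minimal scalar sketch in plain Python of the branch without pos_weight. It is only an illustration, not the PyTorch implementation: the function names bce_with_logits and bce_with_logits_naive are made up for this example, and it works on Python floats rather than tensors.

```python
import math


def bce_with_logits(x: float, y: float) -> float:
    """Stable loss on a raw logit x and target y (no weight, no pos_weight)."""
    # Same role as max_val in the C++ code: shift the exponents so that
    # neither exp() call below can overflow.
    max_val = max(-x, 0.0)
    # log(exp(-max_val) + exp(-x - max_val)) + max_val equals log(1 + exp(-x)).
    log_term = math.log(math.exp(-max_val) + math.exp(-x - max_val))
    return (1.0 - y) * x + max_val + log_term


def bce_with_logits_naive(x: float, y: float) -> float:
    """Same formula without the shift; exp(-x) overflows for very negative x."""
    return (1.0 - y) * x + math.log(1.0 + math.exp(-x))


if __name__ == "__main__":
    print(bce_with_logits(2.0, 1.0), bce_with_logits_naive(2.0, 1.0))
    print(bce_with_logits(-1000.0, 1.0))  # finite result
    try:
        bce_with_logits_naive(-1000.0, 1.0)
    except OverflowError as err:
        print("naive version failed:", err)
```

For moderate logits both versions agree; for a very negative logit only the shifted version returns a finite value.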
You can compare it with BCELoss, implemented in the binary_cross_entropy_out_cpu function.
It is less stable because it applies L = -w (y ln(x) + (1-y) ln(1-x)) directly to the probability x, clamping each log term at -100:
    // Binary cross entropy tensor is defined by the equation:
    // L = -w (y ln(x) + (1-y) ln(1-x))
    return (target_val - scalar_t(1))
        * std::max(scalar_t(std::log(scalar_t(1) - input_val)), scalar_t(-100))
        - target_val * std::max(scalar_t(std::log(input_val)), scalar_t(-100));
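The effect of that clamp can be seen in a small scalar sketch in plain Python. This is an illustration under stated assumptions, not the PyTorch code: the helper names clamped_log, bce_on_probability and sigmoid are hypothetical, weights are left out, and Python floats (float64) are used instead of tensors.

```python
import math


def clamped_log(v: float) -> float:
    """Mirror std::max(std::log(v), -100); log(0) is treated as -inf."""
    if v <= 0.0:
        return -100.0
    return max(math.log(v), -100.0)


def bce_on_probability(p: float, y: float) -> float:
    """L = -(y ln(p) + (1-y) ln(1-p)), written like the C++ return statement."""
    return (y - 1.0) * clamped_log(1.0 - p) - y * clamped_log(p)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


if __name__ == "__main__":
    x, y = 40.0, 0.0
    p = sigmoid(x)  # rounds to exactly 1.0 in float64
    print(p, bce_on_probability(p, y))
```

With a logit of 40 and target 0, the exact loss is about 40, but the sigmoid has already saturated to 1.0, so the information is lost before the loss is computed and the clamped formula returns 100.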