Is there a way to write a custom BCE loss in PyTorch?

I am writing a custom BCE loss in PyTorch, but it sometimes returns -inf and, in most cases, nan. This is due to the log function.

bce_loss=y_true*torch.log2(y_pred) + (one_torch-y_pred)*torch.log2(one_torch-y_pred)

Is there a way to rewrite this? Note that y_pred is a sigmoid output, so it lies between 0 and 1.

Hello Jesujoba!

It is certainly possible to write your own BCE loss function.

Two comments:

First, the second term in what you’ve posted has y_pred twice
and no y_true. That might just be a typo, but if it’s like that in
your actual code, it won’t work.
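For reference, here is a minimal sketch of a corrected probability-based version. It assumes y_pred holds sigmoid outputs and y_true holds 0/1 targets of the same shape. It also adds the leading minus sign that the posted expression lacks (without it you would be maximizing the loss), switches to the natural log (log2 only rescales the loss by a constant), and clamps y_pred away from exactly 0 and 1. The name bce_from_probs and the eps value are illustrative choices, not anything from PyTorch.

```python
import torch
import torch.nn.functional as F

def bce_from_probs(y_pred: torch.Tensor, y_true: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # Hypothetical helper: keep probabilities strictly inside (0, 1) so log never returns -inf.
    p = y_pred.clamp(min=eps, max=1.0 - eps)
    # Note the leading minus sign and that the second term uses y_true, not y_pred.
    loss = -(y_true * torch.log(p) + (1.0 - y_true) * torch.log(1.0 - p))
    return loss.mean()

# Usage: compare against PyTorch's built-in probability-based BCE.
y_true = torch.tensor([1.0, 0.0, 1.0, 0.0])
y_pred = torch.tensor([0.9, 0.2, 0.6, 0.4])
print(bce_from_probs(y_pred, y_true), F.binary_cross_entropy(y_pred, y_true))
```

The clamp keeps the value finite, but its gradient is zero wherever it is active, so saturated predictions stop receiving a learning signal.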

Second, to reduce the likelihood of nan, it is better to modify
your BCE loss to take a logit (from -infinity to infinity), rather
than a probability (from 0 to 1).
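To illustrate why the probability form is fragile, the short sketch below uses an arbitrary logit of 40. In float32, sigmoid(40) rounds to exactly 1.0, so log(1 - p) becomes -inf and multiplying it by a zero target gives nan. Working from the logit avoids this, because log(1 - sigmoid(x)) equals log(sigmoid(-x)), which torch.nn.functional.logsigmoid evaluates stably.

```python
import torch
import torch.nn.functional as F

# A large but ordinary logit (40 is an arbitrary illustrative value).
logit = torch.tensor([40.0])

# In float32, sigmoid(40) rounds to exactly 1.0 ...
p = torch.sigmoid(logit)
log_one_minus_p = torch.log(1 - p)      # ... so this is -inf
nan_term = 0.0 * log_one_minus_p        # and 0 * -inf is nan

# The same quantity computed from the logit stays finite (about -40).
stable = F.logsigmoid(-logit)
print(p.item(), log_one_minus_p, nan_term, stable)
```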

See, for example, PyTorch’s BCEWithLogitsLoss.

This allows you to combine the sigmoid and the log inside your
loss function and use the “log-sum-exp trick” to improve
numerical stability.
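A minimal sketch of such a logit-based loss, assuming the model's final sigmoid is removed so that it outputs raw logits x. Per element, -[y log sigmoid(x) + (1 - y) log(1 - sigmoid(x))] simplifies to max(x, 0) - x*y + log(1 + exp(-|x|)); the max(x, 0) + log(1 + exp(-|x|)) part is log(exp(0) + exp(x)) with the larger exponent factored out, which is the log-sum-exp trick. The function name bce_with_logits is hypothetical, and this is not necessarily the exact code PyTorch uses internally, so the usage line checks it against the built-ins numerically.

```python
import torch
import torch.nn.functional as F

def bce_with_logits(logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # Stable per-element BCE on logits: max(x, 0) - x*y + log(1 + exp(-|x|)).
    # exp(-|x|) is at most 1, so it never overflows, and log1p never sees 0.
    loss = logits.clamp(min=0) - logits * y_true + torch.log1p(torch.exp(-logits.abs()))
    return loss.mean()

# Usage: even very confident (and wrong) logits give a finite loss.
logits = torch.tensor([40.0, -40.0, 0.5, -2.0])
y_true = torch.tensor([0.0, 1.0, 1.0, 0.0])
print(bce_with_logits(logits, y_true), torch.nn.BCEWithLogitsLoss()(logits, y_true))
```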

By the way, what is your motivation in writing your own BCE
loss function?


K. Frank