Implementation of torch.nn.BCEWithLogitsLoss

Hi,

I tried to implement BCEWithLogitsLoss myself.

for example,

import torch

def bce_loss(pred, target):
    pred = torch.sigmoid(pred)
    loss = torch.mean(-torch.sum(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)) / target.size(1))
    return loss

However, the loss comes out quite a bit larger than torch.nn.BCEWithLogitsLoss.

for example,

def bce_loss_pytorch(pred, target):
    m = torch.nn.BCEWithLogitsLoss()
    loss = m(pred, target)
    return loss

I am not sure what the main difference is between my implementation and torch.nn.BCEWithLogitsLoss.

Any ideas?

I think the issue is that you are normalizing twice:

  • The first is the sum-then-divide: torch.sum reduces over all elements, so you should divide by target.numel() instead of target.size(1). Dividing only by size(1) leaves the result larger by a factor of the batch size, which is why your loss comes out bigger.
  • The second is torch.mean, but since the inner expression is already a scalar at that point, the mean is a no-op.

So 2 options:
Here is your code with torch.mean

def bce_loss(pred, target):
    pred = torch.sigmoid(pred)
    loss = torch.mean(-(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)))
    return loss

And here is your code with manual mean (sum and divide)

def bce_loss(pred, target):
    pred = torch.sigmoid(pred)
    loss = torch.sum(-(target * torch.log(pred) + (1 - target) * torch.log(1 - pred))) / target.numel()
    return loss
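
As a quick sanity check (just a sketch; the shapes and data below are made up), both corrected versions should now match the built-in loss on ordinary inputs:

import torch

pred = torch.randn(8, 4)                      # random logits for a batch of 8
target = torch.randint(0, 2, (8, 4)).float()  # random 0/1 targets

print(bce_loss(pred, target))                      # manual version from above
print(torch.nn.BCEWithLogitsLoss()(pred, target))  # built-in; should agree to within float error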

That said, the docs recommend using the official torch.nn.BCEWithLogitsLoss(): although it conceptually applies a sigmoid and then BCE, it combines the two operations in a numerically stable way (I believe the stable implementation is at the C++ level).
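
For reference, here is a Python sketch of the usual stable formulation (my own illustration of the standard log-sum-exp trick, not PyTorch's actual C++ code):

import torch

def bce_with_logits(pred, target):
    # max(x, 0) - x*t + log(1 + exp(-|x|)) is an algebraic rewrite of
    # -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]; taking -|x| keeps
    # exp() from overflowing for large logits
    loss = torch.clamp(pred, min=0) - pred * target + torch.log1p(torch.exp(-torch.abs(pred)))
    return loss.mean()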

Roy.

Hi Roy and 杜明軒!

As a minor note, you can implement your own BCEWithLogitsLoss with the same numerical benefits as PyTorch's by replacing the separate calls to sigmoid() and log() with a call to logsigmoid().
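
Concretely, it might look something like this (my sketch of the idea, using the identity log(1 - sigmoid(x)) = logsigmoid(-x)):

import torch
import torch.nn.functional as F

def bce_loss_stable(pred, target):
    # F.logsigmoid(x) computes log(sigmoid(x)) without under/overflow,
    # and log(1 - sigmoid(x)) = logsigmoid(-x)
    loss = -(target * F.logsigmoid(pred) + (1 - target) * F.logsigmoid(-pred))
    return loss.mean()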

Best.

K. Frank
