How to avoid -inf while implementing binary cross entropy loss?

I found that my BCE loss sometimes hits -inf. My implementation is shown below:

# imagine the model outputs a tensor with shape (N, 2)
# GT is one-hot encoding of shape (N, 2)
prob = torch.softmax(output, dim=1)
loss = torch.sum(-torch.log(prob) * gt, dim=1)
loss = torch.mean(loss)

In some cases, the model is overconfident and outputs two values with a very large difference, e.g. tensor([[-100., 100.]]). The -100 entry makes softmax underflow to 0, and torch.log(0) then returns -inf.
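A minimal repro of this underflow (the tensor values are taken from the example above):

```python
import torch

output = torch.tensor([[-100., 100.]])
prob = torch.softmax(output, dim=1)  # exp(-200) underflows to 0 in float32
log_prob = torch.log(prob)           # log(0) evaluates to -inf
```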

Is there any way to deal with this? Thanks a lot.


You could use the built-in loss functions, such as nn.BCEWithLogitsLoss or nn.CrossEntropyLoss, which improve numerical stability, e.g. via the log-sum-exp trick or by clipping.
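For example, nn.CrossEntropyLoss takes the raw logits directly (no explicit softmax) together with class indices, and computes the log-softmax internally in a stable way. A sketch using the overconfident logits from the question, with the one-hot target converted to indices via argmax:

```python
import torch
import torch.nn as nn

output = torch.tensor([[-100., 100.]])  # overconfident logits
gt = torch.tensor([[0., 1.]])           # one-hot target

criterion = nn.CrossEntropyLoss()
loss = criterion(output, gt.argmax(dim=1))  # stays finite, no -inf
```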

Thanks @ptrblck,
I've tried the built-in functions and they are indeed numerically stable.
But I'm curious why my implementation is not numerically stable. How can I make it more stable?


This blog post and this Wikipedia article explain the subtraction of the max value for numerical stability.
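Applied to the snippet from the question, the max-subtraction trick looks roughly like this (a sketch, not PyTorch's exact internal implementation):

```python
import torch

output = torch.tensor([[-100., 100.]])
gt = torch.tensor([[1., 0.]])

# softmax is invariant to shifting the logits, so subtract the row-wise max;
# the largest shifted logit is then 0, so exp() cannot overflow and the sum
# is always >= 1, which keeps log() away from 0
shifted = output - output.max(dim=1, keepdim=True).values
log_prob = shifted - torch.log(torch.exp(shifted).sum(dim=1, keepdim=True))
loss = torch.mean(torch.sum(-log_prob * gt, dim=1))  # finite, not inf
```

In practice, torch.log_softmax(output, dim=1) performs an equivalent stable computation and is the preferred way to get log-probabilities.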
