Creating a custom BCE-with-logits loss function

I want to create a custom loss function for multi-label classification. The idea is to weigh the positive and negative labels differently. For this, I am using the custom implementation below.

import torch
import torch.nn as nn

class WeightedBCEWithLogitLoss(nn.Module):
    def __init__(self, pos_weight, neg_weight):
        super(WeightedBCEWithLogitLoss, self).__init__()
        # pos_weight and neg_weight are tensors with one weight per label
        self.register_buffer('neg_weight', neg_weight)
        self.register_buffer('pos_weight', pos_weight)

    def forward(self, input, target):
        assert input.shape == target.shape, "The loss function received invalid input shapes"
        y_hat = torch.sigmoid(input + 1e-8)
        loss = -1.0 * (self.pos_weight * target * torch.log(y_hat + 1e-6) + self.neg_weight * (1 - target) * torch.log(1 - y_hat + 1e-6))
        # Account for 0 times inf which leads to nan
        loss[torch.isnan(loss)] = 0
        # We average across each of the extra attribute dimensions to generalize it
        loss = loss.mean(dim=1)
        # We use mean reduction for our task
        return loss.mean()
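
For context, here is a minimal sketch of how I construct and call the loss (the weight values and tensor shapes below are just placeholders, not my real setup):

# Placeholder example: 5 labels, batch of 3, arbitrary per-label weights
pos_weight = torch.tensor([1.0, 2.0, 1.5, 1.0, 3.0])
neg_weight = torch.tensor([1.0, 0.5, 1.0, 1.0, 0.8])

criterion = WeightedBCEWithLogitLoss(pos_weight, neg_weight)

logits = torch.randn(3, 5, requires_grad=True)   # raw model outputs (no sigmoid applied)
targets = torch.randint(0, 2, (3, 5)).float()    # multi-hot ground-truth labels

loss = criterion(logits, targets)
loss.backward()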

I started getting nan values, which I realized happened because of 0 times inf multiplication. I handled it as shown in the code above. Next, I again started getting inf as the loss value and corrected it by adding 1e-6 inside the log (I tried 1e-8, but that still gave me inf).

It would be great if someone could take a look, suggest further improvements, and point out any remaining bugs.

Hi Chinmay!

Do be aware that pytorch’s BCEWithLogitsLoss supports a pos_weight constructor argument that will do what you want. So unless this is a learning exercise, you should simply use BCEWithLogitsLoss.
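
As a minimal sketch of what that would look like (the per-label weights and shapes here are just placeholders):

import torch
import torch.nn as nn

# pos_weight has one entry per label; values > 1 up-weight the positive term
pos_weight = torch.tensor([1.0, 2.0, 1.5, 1.0, 3.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(3, 5, requires_grad=True)   # raw scores (no sigmoid)
targets = torch.randint(0, 2, (3, 5)).float()

loss = criterion(logits, targets)                # mean reduction by default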

The 1e-8 doesn’t do anything useful here. sigmoid() is very well behaved (and equal to 0.5) when its argument is equal to zero, so there’s no need or benefit in trying to move the argument a little bit away from zero. (Furthermore, -1e-8 is a perfectly valid logit and argument to your loss function, and your “fix” just moves it to zero – not that anything bad happens at zero.)
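
A quick illustrative check of this point:

import torch

print(torch.sigmoid(torch.tensor(0.0)))     # tensor(0.5000), perfectly finite
print(torch.sigmoid(torch.tensor(1e-8)))    # essentially identical; the shift buys nothing
print(torch.sigmoid(torch.tensor(-1e-8)))   # a valid logit that "input + 1e-8" just moves to zero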

Here you apply log() to sigmoid(). This is a source of numerical instability. You should use the log-sum-exp “trick” to compute log(sigmoid()). (This is what pytorch’s BCEWithLogitsLoss does internally.)
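
Here is a sketch of the stable formulation, using torch.logaddexp for the log-sum-exp and pytorch’s built-in F.logsigmoid for comparison:

import torch
import torch.nn.functional as F

x = torch.tensor([-200.0, 0.0, 200.0])   # extreme logits

# Naive version: sigmoid(-200.) underflows to 0, so log() returns -inf
naive_log_p = torch.log(torch.sigmoid(x))                    # tensor([-inf, -0.6931, 0.])

# Stable version: log(sigmoid(x)) = -log(exp(0) + exp(-x)) = -logaddexp(0, -x)
stable_log_p = -torch.logaddexp(torch.zeros_like(x), -x)     # tensor([-200., -0.6931, -0.])
stable_log_1mp = -torch.logaddexp(torch.zeros_like(x), x)    # log(1 - sigmoid(x)) = log(sigmoid(-x))

# pytorch exposes the stable computation directly as F.logsigmoid()
print(torch.allclose(stable_log_p, F.logsigmoid(x)))         # True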

With the code you posted, I don’t see why you would be getting nans or infs. The 1e-6 that you add to your log() functions should protect against that.

Best.

K. Frank
