[SOLVED] Class Weighted Binary Cross-Entropy not working, even with equal weights


I have tried using the following custom loss class:

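```python
import torch
from torch import Tensor
from fastai.torch_core import Rank0Tensor  # fastai v1 alias for a one-element tensor

# class BCEWithLogitsLoss():
class CustomWBCE():
    def __init__(self, class_weights=None, **kwargs):
        # class_weights = (w0, w1): weights for class 0 and class 1
        self.class_weights = class_weights

    def __call__(self, output:Tensor, target:Tensor, **kwargs)->Rank0Tensor:
        # The fast.ai model emits raw logits, so convert them to probabilities.
        output = torch.sigmoid(output)
        if output.min() <= 0 or output.max() >= 1:
#             print('Wrong value in output, will give loss nan')
            # Clamp the probabilities so that log() stays finite. In float32,
            # 1 - 1e-8 rounds to exactly 1.0, so a larger epsilon is needed.
            output = torch.clamp(output, min=1e-7, max=1 - 1e-7)
        if self.class_weights is not None:
            assert len(self.class_weights) == 2

            loss = self.class_weights[1] * (target * torch.log(output)) + \
                   self.class_weights[0] * ((1 - target) * torch.log(1 - output))
        else:
            loss = target * torch.log(output) + (1 - target) * torch.log(1 - output)

        loss = torch.neg(torch.mean(loss))

        return loss
```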

The reason I apply the sigmoid is that I am using this in a fast.ai Learner, which does not apply an activation function at the end of the network by default. Sometimes my model output a very small value (-136, for example) and torch.sigmoid returned 0., which led to -inf in torch.log; that's why I added the torch.clamp.
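The underflow described above is easy to reproduce in isolation. This is a minimal sketch; -136 is the example value from the post, and the tensor uses PyTorch's default float32 dtype:

```python
import torch

# A large negative logit, like the -136 mentioned above (float32 by default).
logit = torch.tensor([-136.0])

prob = torch.sigmoid(logit)   # the true value (~1e-59) underflows to 0.0 in float32
log_prob = torch.log(prob)    # log(0) -> -inf, which then poisons the loss

print(prob.item(), log_prob.item())
```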

I have tried it both with weights computed with the formula:

```python
total = negative + positive
w0 = positive / total
w1 = negative / total
```

And with weights (0.5, 0.5) to test it out.
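For reference, the weight formula above applied to a label tensor might look like this. It is a sketch; the labels are hypothetical, chosen so that 20% of the samples are positive:

```python
import torch

# Hypothetical binary labels: 20 positives, 80 negatives.
labels = torch.cat([torch.ones(20), torch.zeros(80)])

positive = labels.sum().item()
negative = labels.numel() - positive
total = negative + positive

# The rarer class gets the larger weight.
w0 = positive / total   # weight for class 0 (negatives)
w1 = negative / total   # weight for class 1 (positives)
print(w0, w1)  # 0.2 0.8
```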

Both times I got very poor results:

AUC scores (AUC ranges from 0 to 1; 0.5 corresponds to random guessing) after training with different weights:

  • Calculated weights (0.20, 0.80): 0.51 AUC
  • Equal weights (0.5, 0.5): 0.45 AUC (worse than random somehow)

Now, if I train the model in the same way but use

```python
loss_bce = fastai.layers.BCEWithLogitsFlat()
```

(which flattens the tensors before applying torch.nn.BCEWithLogitsLoss), it reaches 0.83 AUC.
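One reason the flattening can matter, shown with hypothetical shapes (this is an illustration, not a diagnosis of this particular model): if a network emits logits of shape (N, 1) while the targets have shape (N,), a hand-written elementwise loss silently broadcasts to an (N, N) matrix instead of raising an error. Flattening both tensors first keeps one loss term per sample:

```python
import torch

n = 4
logits = torch.randn(n, 1)                    # hypothetical model output, shape (N, 1)
target = torch.tensor([0.0, 1.0, 1.0, 0.0])   # hypothetical labels, shape (N,)

# Hand-written elementwise term: broadcasting turns (N, 1) with (N,) into (N, N).
p = torch.sigmoid(logits)
broadcast_term = target * torch.log(p) + (1 - target) * torch.log(1 - p)
print(broadcast_term.shape)  # torch.Size([4, 4])

# Flattening the logits first keeps one term per sample.
flat_p = torch.sigmoid(logits.reshape(-1))
flat_term = target * torch.log(flat_p) + (1 - target) * torch.log(1 - flat_p)
print(flat_term.shape)  # torch.Size([4])
```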

The model is a DenseNet121-based binary classifier; I trained it on X-rays from the CheXPert14 dataset.
I am on PyTorch version 1.0.1 and fast.ai version 1.0.52.

Do you have any idea what I’m doing wrong?

Hi Andrei!

First, to answer the question I think you’re asking:

You should be using BCEWithLogitsLoss (as in the comment
in your code).

BCEWithLogitsLoss supports sample weights, which you
can use for class weights.

Let’s say you have class weight w_1 for class 1, and w_0
for class 0. Let w_n be the sample weight for sample n.
Simply set w_n = w_1 if y_n = 1, and w_n = w_0 if
y_n = 0. (This assumes that the y_n are either 0 or 1, as
they should be if they are binary class labels.)
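A minimal sketch of that recipe, assuming the weights are passed per batch through torch.nn.functional.binary_cross_entropy_with_logits, whose weight argument is multiplied elementwise into the unreduced loss (the module form, nn.BCEWithLogitsLoss(weight=...), fixes the weight tensor when the module is constructed). The helper name weighted_bce_with_logits and the example tensors are hypothetical:

```python
import torch
import torch.nn.functional as F

def weighted_bce_with_logits(logits, target, w0, w1):
    """Hypothetical helper: class-weighted BCE built from per-sample weights.

    Sample n gets weight w1 if target[n] == 1 and w0 if target[n] == 0,
    assuming target holds 0./1. floats and has the same shape as logits.
    """
    sample_weights = target * w1 + (1 - target) * w0
    return F.binary_cross_entropy_with_logits(logits, target, weight=sample_weights)

# Hypothetical logits and labels for a batch of four samples.
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
target = torch.tensor([1.0, 0.0, 0.0, 1.0])
loss = weighted_bce_with_logits(logits, target, w0=0.2, w1=0.8)
print(loss.item())
```

Because the weights are derived from target inside the call, the same function can be used for every batch.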

Now some comments:

Note that using class weights w_1 = w_0 = 1/2 doesn’t give
you the same result as an unweighted loss function. It gives
you 1/2 the unweighted loss function. (The loss for each sample
is multiplied by 1/2). This doesn’t matter a lot, but, for example,
with plain-vanilla stochastic-gradient-descent optimization, it
has the effect of reducing your learning rate by a factor of 1/2.
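A quick numerical check of that point, using made-up logits and labels: with w_0 = w_1 = 1/2 every sample weight is 1/2, so the weighted loss is exactly half of the unweighted BCEWithLogitsLoss, and so are its gradients:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([1.5, -0.3, 2.0, -2.5], requires_grad=True)
target = torch.tensor([1.0, 0.0, 0.0, 1.0])

unweighted = nn.BCEWithLogitsLoss()(logits, target)

# Class weights (0.5, 0.5) expressed as per-sample weights.
half = torch.full_like(target, 0.5)
weighted = F.binary_cross_entropy_with_logits(logits, target, weight=half)

print(torch.allclose(weighted, 0.5 * unweighted))  # True

# The gradient is scaled by the same factor, which is what lowers the
# effective learning rate under plain SGD.
g_unweighted, = torch.autograd.grad(unweighted, logits)
g_weighted, = torch.autograd.grad(weighted, logits)
print(torch.allclose(g_weighted, 0.5 * g_unweighted))  # True
```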

Instead of clamping the sigmoid of your output, you should be
using torch.nn.LogSigmoid. This avoids the problem of
large negative --> sigmoid --> 0 --> log --> -inf.
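A minimal sketch of the weighted loss rewritten with log-sigmoid instead of sigmoid + log + clamp. It relies on the identity log(1 - sigmoid(x)) = logsigmoid(-x), so both log terms stay finite even for a logit like -136. The class name LogSigmoidWBCE is hypothetical, and this illustrates the idea rather than being a tested fast.ai drop-in:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LogSigmoidWBCE:
    """Hypothetical class-weighted BCE computed directly from logits."""

    def __init__(self, class_weights=(1.0, 1.0)):
        self.w0, self.w1 = class_weights
        self.log_sigmoid = nn.LogSigmoid()

    def __call__(self, output, target):
        # Flatten so that (N, 1) outputs line up with (N,) targets.
        output, target = output.reshape(-1), target.reshape(-1).float()
        log_p = self.log_sigmoid(output)        # log(sigmoid(x))
        log_not_p = self.log_sigmoid(-output)   # log(1 - sigmoid(x))
        loss = self.w1 * target * log_p + self.w0 * (1 - target) * log_not_p
        return -loss.mean()

logits = torch.tensor([-136.0, 3.0, 0.0])
target = torch.tensor([1.0, 0.0, 1.0])
loss = LogSigmoidWBCE()(logits, target)   # finite, no clamp needed
print(loss.item())
```

With class_weights=(1.0, 1.0) this should match BCEWithLogitsLoss, which is a convenient sanity check before trying other weights.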

In general, when you are testing / debugging something like
this, instead of running your full training code with a “default”
value like weights = (0.5, 0.5), you should try calling your
function on a single sample, with your default value and
compare the single numerical result with the result of the
standard unweighted function you are trying to mimic (in
this case BCEWithLogitsLoss). Only when you are happy
that you have that working should you try running a single
batch, and when that is working, try the training.
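The single-sample, then single-batch check described above might look like this. It is a sketch: custom_loss is a hypothetical stand-in for whatever loss is under test (here a naive sigmoid + log version with both weights set to 1, so it should reproduce BCEWithLogitsLoss):

```python
import torch
import torch.nn as nn

def custom_loss(output, target, w0=1.0, w1=1.0):
    # Hypothetical loss under test: naive class-weighted BCE from logits.
    p = torch.sigmoid(output)
    loss = w1 * target * torch.log(p) + w0 * (1 - target) * torch.log(1 - p)
    return -loss.mean()

reference = nn.BCEWithLogitsLoss()

# Step 1: a single sample, compared number against number.
x1 = torch.tensor([0.7])
y1 = torch.tensor([1.0])
print(custom_loss(x1, y1).item(), reference(x1, y1).item())

# Step 2: a small batch, only once step 1 agrees.
torch.manual_seed(0)
xb = torch.randn(8)
yb = (torch.rand(8) > 0.5).float()
print(torch.allclose(custom_loss(xb, yb), reference(xb, yb), atol=1e-5))
```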

Lastly, I think this discussion – especially the comment about
avoiding clamping – applies to your earlier thread:

and its linked thread:

Best regards.

K. Frank

Thank you for the detailed explanation and the debugging tips. I am trying BCEWithLogitsFlat (the fast.ai wrapper for BCEWithLogitsLoss) right now. So far the weighted version seems to perform as well as the unweighted one.