Use class weight with Binary Cross Entropy Loss

Hello,

I am doing a segmentation project with a U-Net. I have an unbalanced dataset with 2 classes, and as a first step I want to apply a weight to each class. I use the loss torch.nn.BCELoss(). After looking online, it seems that people who had a similar problem were advised to switch to BCEWithLogitsLoss(), which has a pos_weight argument to set the class weight.

  1. I would prefer, if possible, to keep using torch.nn.BCELoss() because, from what I understand, if I used BCEWithLogitsLoss() I would have to remove the sigmoid layer at the end of my network. That means that in every single place where I make a prediction with code like pred = model(data), I would then need to apply a sigmoid to the prediction to get the same output as before (see the sketch below). This doesn't seem great, as there are many places in my code where I use pred = model(data).
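
For concreteness, here is a minimal sketch of the change I am trying to avoid (toy model and input, just to show the two call styles; not my actual U-Net):

import torch
import torch.nn as nn

# toy stand-ins, only to illustrate the call sites
model_with_sigmoid = nn.Sequential(nn.Conv2d(1, 1, 3, padding=1), nn.Sigmoid())
model_logits = nn.Sequential(nn.Conv2d(1, 1, 3, padding=1))
data = torch.randn(1, 1, 8, 8)

pred = model_with_sigmoid(data)           # current code: probabilities directly
pred = torch.sigmoid(model_logits(data))  # after the switch: extra sigmoid at every call site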

Have I understood my problem correctly? Is there any way to keep using BCELoss()?

Thank you for reading.

I believe you should be able to manually weight the unreduced loss if you are using binary targets. If that's not the case, you would need to use nn.BCEWithLogitsLoss with the pos_weight argument. This issue explains the use case in a bit more detail, and this code snippet shows the results:

import torch
import torch.nn as nn

for i in range(100):
    # random positive-class weight for this iteration
    pos_weight = torch.randint(1, 100, (1,)).float()

    # reference: built-in pos_weight handling
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    x = torch.randn(10, 1, requires_grad=True)
    y = torch.randint(0, 2, (10, 1)).float()

    loss = criterion(x, y)

    # same result via manual weighting of the unreduced loss
    criterion_raw = nn.BCEWithLogitsLoss(reduction='none')
    loss_raw = criterion_raw(x, y)
    weight = torch.ones_like(loss_raw)
    weight[y == 1.] = pos_weight  # upweight only the positive samples
    loss_raw = (loss_raw * weight).mean()

    # and the same again with nn.BCELoss on sigmoid outputs
    criterion_raw_sig = nn.BCELoss(reduction='none')
    loss_raw_sig = criterion_raw_sig(torch.sigmoid(x), y)
    loss_raw_sig = (loss_raw_sig * weight).mean()

    # both differences should be (numerically close to) zero
    print(loss - loss_raw)
    print(loss_raw_sig - loss)

Note that nn.BCEWithLogitsLoss improves numerical stability, so it might still be worthwhile to switch to this criterion.
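
For example (a small sketch; the printed values assume float32 and PyTorch's documented clamping of BCELoss's log outputs at -100):

import torch
import torch.nn as nn

x = torch.tensor([20.])  # large logit: sigmoid(20) rounds to exactly 1.0 in float32
y = torch.tensor([0.])

# sigmoid saturates, so BCELoss sees p == 1.0 and returns its clamp value
print(nn.BCELoss()(torch.sigmoid(x), y))   # tensor(100.)
# the fused criterion works in log space and returns the true loss
print(nn.BCEWithLogitsLoss()(x, y))        # tensor(20.0000), i.e. log(1 + exp(20))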


Thank you for your help, it is really appreciated.

I have switched to using nn.BCEWithLogitsLoss(pos_weight=pos_weight) as it seems to be the easiest solution.

Just to be sure: if my dataset has a split of 80% negative examples and 20% positive examples, then pos_weight should be 4, right?

Have a nice day!

Yes, it would be calculated as nb_neg/nb_pos = 80/20 = 4.
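
In code, that calculation would look something like this (toy binary target tensor with an 80/20 split, just for illustration):

import torch

# toy binary targets: 80 negatives, 20 positives
targets = torch.tensor([0., 0., 0., 0., 1.] * 20)

nb_pos = targets.sum()
nb_neg = targets.numel() - nb_pos
pos_weight = nb_neg / nb_pos
print(pos_weight)  # tensor(4.)

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)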
