How to apply weighted loss to a binary segmentation problem?

Hi,

I have a binary segmentation problem, where the label/target tensor is a simple binary mask: the background is represented by 0 and the foreground (the object I want to segment) by 1.

I read that for such problems people have gotten great results using a single-channel output, so the output from my U-Net is of shape [1, 1, 30, 256, 256]. The target tensor has the same shape.

The objects I want to segment are very small, so every sample contains a lot of background but only a few foreground pixels. I was wondering whether applying class weights would help.

I know that with multiple classes to segment, e.g. three classes [0, 1, 2], I could apply weights like [0.2, 0.2, 0.6]. But how do I do the same when I have only one class?

Is the network treating the background as a class even though I chose to output a single channel, or is the background being ignored? If it is not ignored, how do I know how to set the weights? Would it be [0.2, 0.8] or [0.8, 0.2]?

Thank you

I assume you are using nn.BCEWithLogitsLoss or nn.BCELoss as your criterion for this architecture.
In the former case, you could specify the pos_weight argument, where your object class (class 1) would act as the positive class.
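For reference, this is the per-element loss that nn.BCEWithLogitsLoss computes when pos_weight $p$ is set (following the formula in the PyTorch docs, with $x$ the logit, $y \in \{0, 1\}$ the target and $\sigma$ the sigmoid):

$$\ell(x, y) = -\left[\, p \, y \log \sigma(x) + (1 - y) \log\bigl(1 - \sigma(x)\bigr) \right]$$

So the background ($y = 0$) is not ignored: it contributes through the second term with an implicit weight of 1, while pos_weight only scales the foreground term. A value $p > 1$ therefore up-weights the rare foreground pixels relative to the background.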

1 Like

Thanks for replying @ptrblck

However, in trying to understand how to implement this I came across your post here: About BCEWithLogitsLoss's pos_weights

I applied pos_weight accordingly as an int, but I get an error saying it does not expect an int. Searching further led me here: Weights in BCEWithLogitsLoss

So I am back where I started: how do I apply this weight? Or, if it expects a vector, how do I apply the weights in my binary case?

Thanks

You could just pass a torch.FloatTensor for the positive weight:

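import torch
import torch.nn as nn

output = torch.randn(1, 1, 10, 10, requires_grad=True)  # raw logits from a single-channel model
target = torch.randint(0, 2, (1, 1, 10, 10)).float()  # binary mask; BCEWithLogitsLoss needs float targets
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(10.))  # pass a tensor, not a Python int
loss = criterion(output, target)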

Let me know, if that works for you!
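The example above uses a fixed pos_weight of 10. One common heuristic (an assumption here, not something stated in the reply) is to set it to the ratio of background to foreground pixels, which matches the counting rule from the docs quoted further down in this thread. A minimal sketch for a single-channel segmentation output, with the volume shrunk from [1, 1, 30, 256, 256] and a random sparse mask standing in for real labels; in practice you would probably compute the ratio once over the training set rather than per batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dummy single-channel segmentation batch: [batch, channel, depth, height, width].
# Sizes are reduced from [1, 1, 30, 256, 256] to keep the example fast.
logits = torch.randn(1, 1, 4, 32, 32, requires_grad=True)
target = (torch.rand(1, 1, 4, 32, 32) > 0.95).float()  # sparse foreground mask (~5% ones)

# pos_weight = (#background pixels) / (#foreground pixels);
# the clamp avoids dividing by zero if a mask happens to be empty.
num_pos = target.sum()
num_neg = target.numel() - num_pos
pos_weight = num_neg / num_pos.clamp(min=1.0)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # 0-dim tensor broadcasts over the mask
loss = criterion(logits, target)
loss.backward()
```

With roughly 5% foreground, pos_weight comes out around 19, so each foreground pixel counts about as much as 19 background pixels.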

4 Likes

Could this be extended to cross-entropy loss as well? If yes, do I need to give weights for both classes?

nn.CrossEntropyLoss accepts a weight argument with a weight value for each class.
Would this work for you or did I misunderstand the question?

That's about right. I meant: will this work for a binary segmentation problem like the one the OP described?

Yes, if your model outputs two logits, you could apply it on a binary classification use case.

Yes, my outputs are something like [batch, 2, 224, 224], where 2 is the number of classes.
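With a two-channel output like that, a minimal sketch of weighted nn.CrossEntropyLoss could look like the following. The weights are passed in class-index order, so the first entry belongs to the background (index 0) and the second to the foreground (index 1). The [0.1, 0.9] values and the random tensors are illustrative placeholders only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two-channel output: [batch, num_classes=2, H, W], one logit per class per pixel.
output = torch.randn(4, 2, 64, 64, requires_grad=True)
# CrossEntropyLoss expects class indices (0 = background, 1 = foreground)
# with shape [batch, H, W] and dtype long, not a one-hot or float mask.
target = (torch.rand(4, 64, 64) > 0.95).long()

# One weight per class, in class-index order: [background, foreground].
class_weights = torch.tensor([0.1, 0.9])
criterion = nn.CrossEntropyLoss(weight=class_weights)
loss = criterion(output, target)
loss.backward()
```

Note that with the default reduction='mean', the weighted loss is normalized by the sum of the weights of the target pixels, so only the ratio between the two weights matters.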

Hi, I am also building a BERT model with a binary classifier, so my output logits have size [nbatch, 1].
My data is very imbalanced: for 1000 cases, the proportions of positive (class 1) vs. negative (class 0) cases are [0.101521, 0.898479].
While searching for how to apply weights with nn.BCEWithLogitsLoss, I came across another reply of yours:


where you used this method:

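import torch
import torch.nn as nn

output = torch.rand(8, 1)  # dummy probabilities in [0, 1] (BCELoss expects probabilities, not logits)
y = torch.randint(0, 2, (8, 1)).float()  # dummy binary targets

weight = torch.tensor([0.1, 0.9])  # [weight for class 0, weight for class 1]
weight_ = weight[y.view(-1).long()].view_as(y)  # look up each element's weight by its target class
criterion = nn.BCELoss(reduction='none')  # unreduced, per-element loss (reduce=False is deprecated)
loss = criterion(output, y)
loss_class_weighted = loss * weight_
loss_class_weighted = loss_class_weighted.mean()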
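The manual weighting above and pos_weight are closely related: from the loss formula, using weights [1, w] for [class 0, class 1] in the manual approach gives the same value as nn.BCEWithLogitsLoss(pos_weight=torch.tensor(w)), and [0.1, 0.9] is the same thing scaled by an overall factor of 0.1 (i.e. pos_weight = 9). A small numerical check, with random logits and labels as stand-ins and a hypothetical helper manual_weighted_bce:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(16, 1)               # stand-in model logits
y = torch.randint(0, 2, (16, 1)).float()  # stand-in binary labels

def manual_weighted_bce(logits, y, weight):
    # Manual approach: unreduced BCELoss on probabilities,
    # each element scaled by the entry of `weight` for its target class.
    weight_ = weight[y.view(-1).long()].view_as(y)
    loss = nn.BCELoss(reduction='none')(torch.sigmoid(logits), y)
    return (loss * weight_).mean()

# Built-in approach: a single scalar pos_weight applied to the logits.
builtin = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(9.0))(logits, y)

manual_1_9 = manual_weighted_bce(logits, y, torch.tensor([1.0, 9.0]))
manual_01_09 = manual_weighted_bce(logits, y, torch.tensor([0.1, 0.9]))

print(torch.allclose(manual_1_9, builtin, atol=1e-5))          # True
print(torch.allclose(manual_01_09, 0.1 * builtin, atol=1e-5))  # True
```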

I modified it a little to match my case:

weight = torch.tensor([0.101521, 0.898479])  # hard code from entire training dataset
pos_weight = weight[labels.data.view(-1).long()].view_as(labels)
loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight).cuda() # have to add the .cuda() to avoid an error
loss = loss_fct(logits, labels)

Since my batch size is 16, pos_weight here has 16 elements. The model runs, but so far it hasn't improved recall for the positive class.
Then I came across this post of yours, where your example passes a single value as the weight: 10.
It seems the model also runs if I pass only one value instead of nbatch values. But I am confused about which one to follow: should the weight have nbatch elements, or just one?

Also, here are some example results with epochs = 2 (after the first epoch, my validation recall for the positive class is always 0).
Without any loss weighting, my first epoch looks like this:

2020-09-23 00:36:17,114 - utils - INFO - ======== Epoch 1 / 2 ========
2020-09-23 00:36:17,114 - utils - INFO - Training...
2020-09-23 00:37:13,784 - utils - INFO - | epoch   1 |   100/ 1320 batches | lr 2.886e-05 | loss 0.354 | Elapsed 0:00:56
2020-09-23 00:38:10,029 - utils - INFO - | epoch   1 |   200/ 1320 batches | lr 2.773e-05 | loss 0.320 | Elapsed 0:01:52
2020-09-23 00:39:06,261 - utils - INFO - | epoch   1 |   300/ 1320 batches | lr 2.659e-05 | loss 0.305 | Elapsed 0:02:49
2020-09-23 00:40:03,269 - utils - INFO - | epoch   1 |   400/ 1320 batches | lr 2.545e-05 | loss 0.319 | Elapsed 0:03:46
2020-09-23 00:41:00,155 - utils - INFO - | epoch   1 |   500/ 1320 batches | lr 2.432e-05 | loss 0.288 | Elapsed 0:04:43
2020-09-23 00:41:57,041 - utils - INFO - | epoch   1 |   600/ 1320 batches | lr 2.318e-05 | loss 0.290 | Elapsed 0:05:39
2020-09-23 00:42:54,411 - utils - INFO - | epoch   1 |   700/ 1320 batches | lr 2.205e-05 | loss 0.317 | Elapsed 0:06:37
2020-09-23 00:43:51,109 - utils - INFO - | epoch   1 |   800/ 1320 batches | lr 2.091e-05 | loss 0.316 | Elapsed 0:07:33
2020-09-23 00:44:47,656 - utils - INFO - | epoch   1 |   900/ 1320 batches | lr 1.977e-05 | loss 0.278 | Elapsed 0:08:30
2020-09-23 00:45:45,236 - utils - INFO - | epoch   1 |  1000/ 1320 batches | lr 1.864e-05 | loss 0.285 | Elapsed 0:09:28
2020-09-23 00:46:41,736 - utils - INFO - | epoch   1 |  1100/ 1320 batches | lr 1.750e-05 | loss 0.260 | Elapsed 0:10:24
2020-09-23 00:47:39,109 - utils - INFO - | epoch   1 |  1200/ 1320 batches | lr 1.636e-05 | loss 0.316 | Elapsed 0:11:21
2020-09-23 00:48:36,184 - utils - INFO - | epoch   1 |  1300/ 1320 batches | lr 1.523e-05 | loss 0.260 | Elapsed 0:12:19
2020-09-23 00:48:47,433 - utils - INFO - 
2020-09-23 00:48:47,434 - utils - INFO -   Training epoch took: 0:12:30
2020-09-23 00:48:47,434 - utils - INFO - 
2020-09-23 00:48:47,434 - utils - INFO - Validating...
2020-09-23 00:49:21,259 - utils - INFO - 
2020-09-23 00:49:21,259 - utils - INFO - | loss 0.019 | Elapsed 0:00:33
2020-09-23 00:49:21,267 - utils - INFO -   
              precision    recall  f1-score      support
0.0            0.893408  1.000000  0.943703  2087.000000
1.0            0.000000  0.000000  0.000000   249.000000
accuracy       0.893408  0.893408  0.893408     0.893408
macro avg      0.446704  0.500000  0.471852  2336.000000
weighted avg   0.798177  0.893408  0.843112  2336.000000

With the first method (passing nbatch weights when calculating the loss):

2020-09-23 19:15:53,890 - utils - INFO - ======== Epoch 1 / 2 ========
2020-09-23 19:15:53,890 - utils - INFO - Training...
2020-09-23 19:16:50,893 - utils - INFO - | epoch   1 |   100/ 1320 batches | lr 2.886e-05 | loss 0.330 | Elapsed 0:00:57
2020-09-23 19:17:47,841 - utils - INFO - | epoch   1 |   200/ 1320 batches | lr 2.773e-05 | loss 0.299 | Elapsed 0:01:53
2020-09-23 19:18:43,978 - utils - INFO - | epoch   1 |   300/ 1320 batches | lr 2.659e-05 | loss 0.284 | Elapsed 0:02:50
2020-09-23 19:19:41,488 - utils - INFO - | epoch   1 |   400/ 1320 batches | lr 2.545e-05 | loss 0.296 | Elapsed 0:03:47
2020-09-23 19:20:38,118 - utils - INFO - | epoch   1 |   500/ 1320 batches | lr 2.432e-05 | loss 0.272 | Elapsed 0:04:44
2020-09-23 19:21:35,336 - utils - INFO - | epoch   1 |   600/ 1320 batches | lr 2.318e-05 | loss 0.271 | Elapsed 0:05:41
2020-09-23 19:22:32,712 - utils - INFO - | epoch   1 |   700/ 1320 batches | lr 2.205e-05 | loss 0.290 | Elapsed 0:06:38
2020-09-23 19:23:29,436 - utils - INFO - | epoch   1 |   800/ 1320 batches | lr 2.091e-05 | loss 0.291 | Elapsed 0:07:35
2020-09-23 19:24:26,124 - utils - INFO - | epoch   1 |   900/ 1320 batches | lr 1.977e-05 | loss 0.258 | Elapsed 0:08:32
2020-09-23 19:25:23,942 - utils - INFO - | epoch   1 |  1000/ 1320 batches | lr 1.864e-05 | loss 0.265 | Elapsed 0:09:30
2020-09-23 19:26:20,755 - utils - INFO - | epoch   1 |  1100/ 1320 batches | lr 1.750e-05 | loss 0.253 | Elapsed 0:10:26
2020-09-23 19:27:18,639 - utils - INFO - | epoch   1 |  1200/ 1320 batches | lr 1.636e-05 | loss 0.296 | Elapsed 0:11:24
2020-09-23 19:28:15,900 - utils - INFO - | epoch   1 |  1300/ 1320 batches | lr 1.523e-05 | loss 0.243 | Elapsed 0:12:22
2020-09-23 19:28:27,048 - utils - INFO - 
2020-09-23 19:28:27,048 - utils - INFO -   Training epoch took: 0:12:33
2020-09-23 19:28:27,048 - utils - INFO - 
2020-09-23 19:28:27,048 - utils - INFO - Validating...
2020-09-23 19:29:03,972 - utils - INFO - 
2020-09-23 19:29:03,972 - utils - INFO - | loss 0.017 | Elapsed 0:00:36
2020-09-23 19:29:03,980 - utils - INFO -   
              precision    recall  f1-score      support
0.0            0.893408  1.000000  0.943703  2087.000000
1.0            0.000000  0.000000  0.000000   249.000000
accuracy       0.893408  0.893408  0.893408     0.893408
macro avg      0.446704  0.500000  0.471852  2336.000000
weighted avg   0.798177  0.893408  0.843112  2336.000000

With the second method (passing a single value as the weight):

2020-09-23 19:51:56,192 - utils - INFO - ======== Epoch 1 / 2 ========
2020-09-23 19:51:56,192 - utils - INFO - Training...
2020-09-23 19:52:53,044 - utils - INFO - | epoch   1 |   100/ 1320 batches | lr 2.886e-05 | loss 1.290 | Elapsed 0:00:56
2020-09-23 19:53:49,731 - utils - INFO - | epoch   1 |   200/ 1320 batches | lr 2.773e-05 | loss 1.212 | Elapsed 0:01:53
2020-09-23 19:54:45,798 - utils - INFO - | epoch   1 |   300/ 1320 batches | lr 2.659e-05 | loss 1.153 | Elapsed 0:02:49
2020-09-23 19:55:43,080 - utils - INFO - | epoch   1 |   400/ 1320 batches | lr 2.545e-05 | loss 1.128 | Elapsed 0:03:46
2020-09-23 19:56:39,890 - utils - INFO - | epoch   1 |   500/ 1320 batches | lr 2.432e-05 | loss 1.150 | Elapsed 0:04:43
2020-09-23 19:57:36,857 - utils - INFO - | epoch   1 |   600/ 1320 batches | lr 2.318e-05 | loss 1.102 | Elapsed 0:05:40
2020-09-23 19:58:33,921 - utils - INFO - | epoch   1 |   700/ 1320 batches | lr 2.205e-05 | loss 1.146 | Elapsed 0:06:37
2020-09-23 19:59:30,677 - utils - INFO - | epoch   1 |   800/ 1320 batches | lr 2.091e-05 | loss 1.154 | Elapsed 0:07:34
2020-09-23 20:00:27,239 - utils - INFO - | epoch   1 |   900/ 1320 batches | lr 1.977e-05 | loss 1.062 | Elapsed 0:08:31
2020-09-23 20:01:24,611 - utils - INFO - | epoch   1 |  1000/ 1320 batches | lr 1.864e-05 | loss 1.081 | Elapsed 0:09:28
2020-09-23 20:02:21,271 - utils - INFO - | epoch   1 |  1100/ 1320 batches | lr 1.750e-05 | loss 1.075 | Elapsed 0:10:25
2020-09-23 20:03:18,690 - utils - INFO - | epoch   1 |  1200/ 1320 batches | lr 1.636e-05 | loss 1.198 | Elapsed 0:11:22
2020-09-23 20:04:15,625 - utils - INFO - | epoch   1 |  1300/ 1320 batches | lr 1.523e-05 | loss 1.015 | Elapsed 0:12:19
2020-09-23 20:04:26,809 - utils - INFO - 
2020-09-23 20:04:26,809 - utils - INFO -   Training epoch took: 0:12:30
2020-09-23 20:04:26,809 - utils - INFO - 
2020-09-23 20:04:26,809 - utils - INFO - Validating...
2020-09-23 20:05:00,461 - utils - INFO - 
2020-09-23 20:05:00,461 - utils - INFO - | loss 0.065 | Elapsed 0:00:33
2020-09-23 20:05:00,468 - utils - INFO -   
              precision    recall  f1-score      support
0.0            0.943584  0.817441  0.875995  2087.000000
1.0            0.278409  0.590361  0.378378   249.000000
accuracy       0.793236  0.793236  0.793236     0.793236
macro avg      0.610997  0.703901  0.627187  2336.000000
weighted avg   0.872681  0.793236  0.822953  2336.000000

As you can see, with the second weighting method the training loss reported every 100 batches stays above 1. However, even after the first epoch, my recall is much higher.

After 2 epochs, compared to the unweighted version, weighting method 1 increases precision slightly (by 5%), while recall stays the same (~10%).
With weighting method 2, recall improves to ~40%, while precision drops from 90% to 33%.

Summary of my questions:

  1. Can you explain the difference between the first and the second weighting method?
  2. The goal of my model is high precision for the positive class, but I also want to make sure recall is not too low. If both weighting methods work, which direction should I go?
  3. Do you have any idea why, with the unweighted loss or weighting method 1, my recall after the first epoch is always 0?

Thank you so much in advance!

The first method was posted in 2018, when pos_weight was most likely not implemented yet, so you don't need the manual approach of multiplying the weight tensor with the unreduced loss anymore.
From the docs for nn.BCEWithLogitsLoss:

For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to 300/100 = 3. The loss would act as if the dataset contains 3 * 100 = 300 positive examples.
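Applied to the class proportions posted above (positive/class 1: 0.101521, negative/class 0: 0.898479), the same rule gives a single pos_weight of roughly 8.85. A short sketch, wrapping the value in a one-element tensor for the single output logit:

```python
import torch
import torch.nn as nn

# Class proportions from the question: positive (class 1) and negative (class 0).
pos_frac, neg_frac = 0.101521, 0.898479

# Same rule as the docs' example (300 / 100 = 3): negatives divided by positives.
pos_weight = neg_frac / pos_frac
print(round(pos_weight, 2))  # 8.85

# A tensor with a single element, matching the single output logit of shape [nbatch, 1].
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
```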

So for your binary classification use case you should pass a tensor with a single element to the criterion.
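This may also explain why the per-sample version did not help (my reading of the loss formula, not something stated above): when a tensor built from weight[labels] is passed as pos_weight, negative samples are unaffected, because pos_weight only multiplies the positive-target term, while positive samples are scaled by 0.898479 < 1, i.e. slightly down-weighted. A small check with made-up logits and labels:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(16, 1)                                # stand-in logits, batch of 16
labels = torch.tensor([1.0] * 4 + [0.0] * 12).view(16, 1)  # 4 positives, 12 negatives

# Per-sample "pos_weight" built like the first method: one entry per sample.
weight = torch.tensor([0.101521, 0.898479])
per_sample = weight[labels.view(-1).long()].view_as(labels)

plain = nn.BCEWithLogitsLoss(reduction='none')(logits, labels)
per_sample_loss = nn.BCEWithLogitsLoss(pos_weight=per_sample, reduction='none')(logits, labels)

neg = labels == 0
pos = labels == 1
# Negatives: pos_weight multiplies y * log(sigmoid(x)), which is zero when y = 0.
print(torch.allclose(per_sample_loss[neg], plain[neg]))             # True
# Positives: their loss is multiplied by 0.898479, i.e. down-weighted.
print(torch.allclose(per_sample_loss[pos], 0.898479 * plain[pos]))  # True
```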