"weight" vs. "pos_weight" in nn.BCEWithLogitsLoss()

I’m confused by the explanation given in the official docs:

  • pos_weight (Tensor, optional ) – a weight of positive examples. Must be a vector with length equal to the number of classes.

For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to 300/100=3. The loss would act as if the dataset contains 3×100=300 positive examples.

  • weight (Tensor, optional ) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.

In the context of multi-label binary classification (e.g., semantic segmentation):

  1. How is weight supposed to be calculated when the ground-truth mask has shape [N, 1, H, W]?
  2. How do the two parameters differ, and how does each contribute to the calculation of the BCE loss?

According to the docs, in CrossEntropyLoss (CEL) the weights you provide are manual rescaling weights given to each class. They let the classes with the fewest occurrences have the same influence on the loss as the classes with the most occurrences.

Here is a small script that could help you get these weights:

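import glob

import cv2
import numpy as np

n_classes = 13
# Hypothetical location of the label masks; adjust to your dataset.
file_list = sorted(glob.glob("labels/*.png"))

# Count the pixels of each class over the whole dataset.
total_pxls = np.zeros(n_classes)
for img_name in file_list:
    print("image: %s" % img_name)
    # -1 is cv2.IMREAD_UNCHANGED: read the label values exactly as stored
    new_lbl = cv2.imread(img_name, cv2.IMREAD_UNCHANGED).astype(np.uint8)
    for k in range(n_classes):
        total_pxls[k] += np.sum(new_lbl == k)

# Frequency of each class, ignoring class 0 (background).
no_total_pxls = np.sum(total_pxls[1:])
freqs = total_pxls[1:] / no_total_pxls
# Weight = average frequency / class frequency, so rare classes get larger weights.
mean_freq = np.mean(freqs)
my_freqs = mean_freq / (freqs + 1e-13)
print(['%4.4f' % x for x in np.round(my_freqs, 3)])  # these are the weights (classes 1..n_classes-1)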

This answer is helpful, but it doesn’t cover the full scope of the original question, which is what I was hoping for. What happens in a binary semantic segmentation problem (e.g., output prediction of size [B, 1, H, W] and a mask of the same size, i.e., only one output channel, not 2)?

Hi Cameron!

The short answer is that BCEWithLogitsLoss secretly doesn’t really
have a notion of “classes” inside of it. The weights you pass in just
match up with the elements of your predictions and targets. Whether
you have one or two or more “output channels” (or no “channels”
dimension) is irrelevant – it’s just one more dimension that
BCEWithLogitsLoss processes in the same way.
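A minimal sketch of the single-channel case asked about above, assuming a prediction and mask of shape [B, 1, H, W] (the sizes and the value 3.0 are made up for illustration). A one-element pos_weight broadcasts over every pixel, so it gives the same loss as a full-size tensor filled with that value. Since broadcasting aligns trailing dimensions, a per-channel pos_weight for a [B, C, H, W] input would presumably need shape [C, 1, 1] rather than [C].

```python
import torch

torch.manual_seed(0)

B, H, W = 2, 4, 4
# Single-channel binary segmentation: one logit and one 0/1 label per pixel.
logits = torch.randn(B, 1, H, W)
mask = (torch.rand(B, 1, H, W) > 0.7).float()

# One-element pos_weight: broadcast over all B * 1 * H * W elements.
loss_small = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))(logits, mask)

# Full-size pos_weight holding the same value everywhere.
pw_full = torch.full((B, 1, H, W), 3.0)
loss_full = torch.nn.BCEWithLogitsLoss(pos_weight=pw_full)(logits, mask)

print(loss_small.item(), loss_full.item())
```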

Some more detail:

(A disclaimer: I might not be entirely right about this. It’s also
conceivable that the generality of the operation of BCEWithLogitsLoss
has increased over time, so my comments might not apply to older
versions of pytorch. I am currently working with version 1.7.1.)

The documentation for BCEWithLogitsLoss is somewhat misleading. It
talks about batches and classes and such, but it just has dimensions;
it doesn’t, per se, operate on batch dimensions and class dimensions,
and so on. The user is allowed to think in those terms, but that
doesn’t affect what BCEWithLogitsLoss actually does.

Consider this example:

import torch
print (torch.__version__)

_ = torch.manual_seed (2021)

nBatch = 2

# consider a "pos_weighted" multi-label, 15-class use case:

nClass = 15

pred_class = torch.randn (nBatch, nClass)
targ_class = torch.rand (nBatch, nClass)
posw_class = torch.rand (nBatch, nClass)

loss_class = torch.nn.BCEWithLogitsLoss (pos_weight = posw_class) (pred_class, targ_class)
print ('loss_class =', loss_class)

# but is there any difference between 15 classes and 3 * 5 image pixels?

height = 3
width = 5

pred_image = pred_class.reshape (nBatch, height, width)
targ_image = targ_class.reshape (nBatch, height, width)
posw_image = posw_class.reshape (nBatch, height, width)

loss_image = torch.nn.BCEWithLogitsLoss (pos_weight = posw_image) (pred_image, targ_image)
print ('loss_image =', loss_image)

# and is the batch dimension treated specially?

loss_flat = torch.nn.BCEWithLogitsLoss (pos_weight = posw_class.flatten()) (pred_class.flatten(), targ_class.flatten())
print ('loss_flat =', loss_flat)

And here is its output:

1.7.1
loss_class = tensor(0.6958)
loss_image = tensor(0.6958)
loss_flat = tensor(0.6958)

The documentation hints at this in its “shape” section:

Shape:

       Input: (N, *) where "*" means, any number of additional dimensions

       Target: (N, *), same shape as the input

       Output: scalar. If reduction is 'none', then (N, *), same shape as input.

But BCEWithLogitsLoss doesn’t actually differentiate between the
batch dimension and the other dimensions. I think this part of the
documentation should simply read:

Shape:

       Input: (*) where "*" means, any number of dimensions

(Use cases with “batches” and “classes” and “channels” and “images”
could be illustrated with examples, but the documentation should state
clearly that there is no operational difference.)

(As far as I can tell, input and target must have the same shape,
and pos_weight must be broadcastable over input and target.
weight enters the per-element loss a little differently: it
multiplies the whole per-element term, while pos_weight multiplies
only the positive-target term. In terms of shapes and dimensions,
though, weight behaves the same as pos_weight.)
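For reference, the PyTorch docs give the per-element loss as l = -w * [p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], with w the weight, p the pos_weight, x the logit and y the target. The sketch below checks this by hand on random tensors (shapes and values are arbitrary). It also shows that, for hard 0/1 targets, a weight equal to p on positive pixels and 1 elsewhere reproduces pos_weight = p, which is one way to think about building weight for an [N, 1, H, W] mask.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(2, 1, 3, 5)                    # logits
y = (torch.rand(2, 1, 3, 5) > 0.5).float()     # hard 0/1 targets
w = torch.rand(2, 1, 3, 5)                     # per-element "weight"
p = torch.tensor([3.0])                        # "pos_weight", broadcast

# Manual per-element loss; logsigmoid(-x) == log(1 - sigmoid(x)).
manual = -w * (p * y * F.logsigmoid(x) + (1 - y) * F.logsigmoid(-x))
builtin = torch.nn.BCEWithLogitsLoss(weight=w, pos_weight=p)(x, y)
print(manual.mean().item(), builtin.item())

# With hard targets, weight = p on positives and 1 on negatives
# gives the same loss as pos_weight = p.
w_from_mask = torch.where(y == 1, p, torch.ones_like(y))
loss_w = torch.nn.BCEWithLogitsLoss(weight=w_from_mask)(x, y)
loss_pw = torch.nn.BCEWithLogitsLoss(pos_weight=p)(x, y)
print(loss_w.item(), loss_pw.item())
```

With soft (non-0/1) targets the two are no longer interchangeable, because pos_weight scales only the y term while weight scales both terms.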

Best.

K. Frank


Thank you for the response!

Yeah, after some discussion on a related post, I realized that with the default reduction ('mean') the loss returns a single value: the mean of the per-pixel losses. So backpropagation starts from a single scalar. This made a lot more sense to me afterwards.
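A small illustration of that point, with arbitrary sizes: reduction='none' returns the per-pixel loss map, and the default reduction='mean' is just its average, a scalar from which backward() starts. Gradients still reach every pixel's logit.

```python
import torch

torch.manual_seed(0)

logits = torch.randn(2, 1, 4, 4, requires_grad=True)
mask = (torch.rand(2, 1, 4, 4) > 0.5).float()
pw = torch.tensor([2.0])

per_pixel = torch.nn.BCEWithLogitsLoss(pos_weight=pw, reduction='none')(logits, mask)
mean_loss = torch.nn.BCEWithLogitsLoss(pos_weight=pw)(logits, mask)  # default: 'mean'

print(per_pixel.shape)   # torch.Size([2, 1, 4, 4])
print(mean_loss.shape)   # torch.Size([]), i.e. a scalar

mean_loss.backward()
print(logits.grad.shape)  # one gradient per pixel logit
```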

@Manuel_Alejandro_Dia Hi, I have a question about the binary classification problem. In my case, the negative class is the minority class. Would pos_weight work in that case?

Yes!

In this case, your pos_weight will be under 1, since the positive class is the majority one.

Basically, what you want in the end is to use this parameter to balance the contribution each class makes to the loss.
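Following the docs’ rule (pos_weight = number of negatives / number of positives), here is a sketch of computing it from 0/1 training targets; the helper name `compute_pos_weight` and the toy label counts are made up. With 300 negatives and 100 positives it gives 3.0, as in the docs; when negatives are the minority, e.g. 20 negatives and 80 positives, it gives 0.25, i.e. below 1.

```python
import torch

def compute_pos_weight(targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pos_weight = (#negatives) / (#positives) for 0/1 targets."""
    num_pos = targets.sum()
    num_neg = targets.numel() - num_pos
    return (num_neg / num_pos).reshape(1)

# Toy labels with the docs' counts: 100 positives, 300 negatives.
docs_case = torch.cat([torch.ones(100), torch.zeros(300)])
# Negative class is the minority: 80 positives, 20 negatives.
minority_neg = torch.cat([torch.ones(80), torch.zeros(20)])

print(compute_pos_weight(docs_case))     # tensor([3.])
print(compute_pos_weight(minority_neg))  # tensor([0.2500])

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=compute_pos_weight(minority_neg))
```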


Hi, I have one more question. How can I address class imbalance with BCELoss, since it has weight but no pos_weight? What should weight be if the negative:positive ratio of my labels is 20:80?
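One possible approach, sketched under assumptions: since BCELoss applies weight element-wise (broadcast like pos_weight), you can build a per-element weight from hard 0/1 targets. The values w_pos = 20/80 = 0.25 and w_neg = 1 are a hypothetical choice that mirrors the pos_weight rule, so both classes end up with a total weight of 20; they are not the only reasonable choice. BCELoss expects probabilities, so the sigmoid is applied first; for hard targets this matches BCEWithLogitsLoss with pos_weight = 0.25.

```python
import torch

torch.manual_seed(0)

# Toy labels with negative:positive = 20:80.
target = torch.cat([torch.zeros(20), torch.ones(80)])
logits = torch.randn(100)
probs = torch.sigmoid(logits)  # BCELoss takes probabilities, not logits

# Hypothetical choice mirroring pos_weight = #neg / #pos = 20 / 80.
w_pos, w_neg = 20 / 80, 1.0
weight = torch.where(target == 1, torch.tensor(w_pos), torch.tensor(w_neg))

loss_bce = torch.nn.BCELoss(weight=weight)(probs, target)

# Same loss via BCEWithLogitsLoss and pos_weight (valid for hard 0/1 targets).
loss_logits = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([w_pos]))(logits, target)
print(loss_bce.item(), loss_logits.item())
```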