Weights in BCEWithLogitsLoss

I guess what's missing from the documentation is that the last dim is the dimension of the M classes? The other dimensions of the score tensor are H, W, D… (can I call them geometric dimensions?). But in the case of binary classification, the last dim is also a geometric dimension. That's why the use of BCEWithLogitsLoss is confusing.

Dear ptrblck,
Should pos_weight be computed for every batch, or should it use the statistic num_neg / num_pos over the whole dataset? Would you give any recommendation?
Thanks!

The pos_weight is usually computed for the complete training dataset and passed during the instantiation of the criterion.
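
A minimal sketch of that, assuming the training targets are available as a single binary tensor (the names and placeholder data here are illustrative):

import torch
import torch.nn as nn

# all binary targets of the training set, flattened into one tensor (placeholder data)
train_targets = torch.randint(0, 2, (10000,)).float()

num_pos = train_targets.sum()
num_neg = train_targets.numel() - num_pos

# pos_weight = num_neg / num_pos up-weights the (usually rarer) positive class
criterion = nn.BCEWithLogitsLoss(pos_weight=num_neg / num_pos)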


Unfortunately, all answers are ambiguous and useless.


Hello, sorry to kickstart this thread again, but I tried using the pos_weight as described here and it doesn’t seem to be working for me. I triple checked everything and I can’t seem to find the issue. Can you confirm if my understanding is correct?

I have a binary classification problem and am using my own implementation of a U-Net (a fancy CNN with a decoder rather than a fully connected layer; see [1505.04597] U-Net: Convolutional Networks for Biomedical Image Segmentation).

I decided to have the output prediction be of size [B, 1, 192, 192] (B = batch size) and just interpret a result > 0.5 as class 1 and anything else as class 0, where class 1 is the class I want to predict. My ground truth data also has the shape [B, 1, 192, 192] when I feed it (and the prediction) into the BCEWithLogitsLoss function; I then do loss.backward() and optimizer.step(). I am also zeroing the gradients before each batch is processed.
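
For reference, a minimal sketch of the thresholding described above, assuming the raw logits are passed through a sigmoid before comparing against 0.5:

import torch

logits = torch.randn(4, 1, 192, 192)    # raw model output of shape [B, 1, 192, 192]
probs = torch.sigmoid(logits)           # logits -> probabilities
pred_mask = (probs > 0.5).long()        # 1 = class 1 (positive), 0 = class 0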

Questions:

  1. Is class 1 considered the “positive class”?
  2. If yes, how does the system know that class “0” is the negative class (which is only something I interpret)? I would have thought I needed 2 channels for this rather than 1.
  3. As I am working with single values at each x, y index of the 192x192 prediction, does the pos_weight parameter even work?

Can someone tell me if my understanding (based on what I said above) is flawed?

Thank you in advance.

1 Like
  1. Yes, that’s the default interpretation. You could of course interpret the values as you wish and redefine the positive and negative class, as well as recall, precision, etc.

  2. By definition the 0s would be the negative class. The pos_weight usage is shown in the formula in the docs, so you can re-interpret it if needed (see point 1). No, binary classification/segmentation with nn.BCEWithLogitsLoss expects a single output channel for the two classes. Multi-label classification/segmentation (where each sample/pixel can belong to zero, one, or more classes) expects an output channel per class.

  3. Yes, as seen here:

import torch
import torch.nn as nn

# random logits and a sparse binary target (100 positive pixels)
x = torch.randn(2, 1, 24, 24, requires_grad=True)
y = torch.zeros(2 * 1 * 24 * 24)
y[torch.randint(0, y.nelement(), (100,))] = 1.
y = y.view_as(x)
print('y: 0s: {}, 1s: {}, nelement: {}'.format(
    (y==0.).sum(), y.sum(), y.nelement()))


# unweighted loss
criterion = nn.BCEWithLogitsLoss()
loss = criterion(x, y)
print(loss)

# pos_weight = num_neg / num_pos up-weights the rare positive pixels
criterion_weighted = nn.BCEWithLogitsLoss(pos_weight=(y==0.).sum()/y.sum())
loss_weighted = criterion_weighted(x, y)
print(loss_weighted)

Wow! That was a quick response! Thank you!

I wasn’t expecting a response so late in the night (it’s 1:30 AM in my time zone).

I will inspect this code and post another reply should I have any other questions.

Thank you again!

Hello, as promised, I have some more questions lol:

  1. I see that the loss value has indeed increased when applying the pos_weight. I am struggling to understand the -w_n value and what the sigma represents in the formula here. I think the sigma is actually the sigmoid function 1/(1+e^(-x)).
  2. Am I to use the first or the second formula for my situation using pos_weight?
  3. Calling loss = criterion(x, y) yields a tensor with a single value in it. I would have thought this would return a tensor of the same size as the supplied tensors, so the modification can be made to each pixel location during backprop. Is this single value applied to every value in the predicted tensors?
  4. Would I be able to use the weight parameter rather than pos_weight in this instance (to actively decrease the importance of class 0)?

Thank you again. Sorry for my ignorance. I am just getting started with PyTorch.

  1. The w_n value is defined by the weight parameter of the criterion, and the sigma is indeed the sigmoid function:
  • weight (Tensor, optional) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.
  2. The second one, since p_c defines the pos_weight (the formula is reproduced after this list).

  3. By default reduction='mean' will be used. If you want to get the unreduced loss tensor, you could use reduction='none'.

  4. The weight parameter would weight each sample, so I don’t think it would yield the same results, since w_n would be applied to both terms.
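
For reference, the per-element formula from the nn.BCEWithLogitsLoss docs, with w_n the weight, p_c the pos_weight, and sigma the sigmoid, is:

\ell_n = -w_n \left[ p_c \, y_n \log \sigma(x_n) + (1 - y_n) \log\big(1 - \sigma(x_n)\big) \right]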


Awesome! Thank you very much!

And I didn’t realize that “reduction” was referring to either returning the mean of the entire tensor or the unreduced tensor itself. Thank you for informing me.

I tried the weighted BCEWithLogitsLoss with the following:
input = (B, 1, 256, 256)
model_output = (B, 3, 256, 256)
ground_truth = (B, 3, 256, 256)

Initially I set pos_weight as a tensor of size 3, but it was showing an error. Then I changed it to the size (3, 256, 256) and it worked (better than without weights).
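
(Presumably the error came from broadcasting: pos_weight has to broadcast against the [B, 3, 256, 256] target, so a per-channel weight needs the shape (3, 1, 1) rather than (3,). A minimal sketch, with made-up weight values:)

import torch
import torch.nn as nn

logits = torch.randn(4, 3, 256, 256, requires_grad=True)
target = torch.randint(0, 2, (4, 3, 256, 256)).float()

# per-channel pos_weight, reshaped so it broadcasts over [B, 3, 256, 256]
pos_weight = torch.tensor([1.0, 2.0, 5.0]).view(3, 1, 1)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, target)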

For calculating pos_weight, should we use only the train dataset or the entire dataset for multi-label classification?


Only the train dataset.


How can we balance the binary outputs (i.e., weight all the classes)?
For example, in a multi-label classification over [dog, cat, rabbit] where BCEWithLogitsLoss is used, how can we weight the importance of the three classes dog, cat, and rabbit?

pos_weight expects a tensor with the length equal to the number of classes for multi-label use cases, so you could provide a separate weight for each class.
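
A minimal sketch for the [dog, cat, rabbit] example above (the weight values are made up):

import torch
import torch.nn as nn

batch_size = 8
logits = torch.randn(batch_size, 3, requires_grad=True)   # one column per class: [dog, cat, rabbit]
targets = torch.randint(0, 2, (batch_size, 3)).float()

# one pos_weight per class, e.g. num_neg / num_pos computed on the training set
pos_weight = torch.tensor([1.0, 3.0, 10.0])

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(logits, targets)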

Sorry for my confusion. In my practice, weight also expects a tensor with a length equal to the number of classes for multi-label use cases, rather than the size of the sample batch.
In my understanding, pos_weight controls the positive/negative balance within each class, but I want to control the balance between the classes.
Hoping for your reply, thank you!

In this case you might want to use an unreduced loss and apply the class weights afterwards.
Since you are working on a multi-label classification I’m unsure what kind of weights you are planning to apply for samples with multiple positive class labels.


It seems that the role of weight differs between multi-label classification and multi-class single-label classification. In the first case, weight should have the same length as the number of label categories. In the second case, weight should have the same size as nbatch.

Yes, since it might not be trivial to apply weights to a multi-label classification use case.
Let me give you an example.
In a multi-class classification you can directly apply a class weight to the corresponding sample as seen here:

# multi-class classification
import torch
import torch.nn as nn

batch_size = 10
nb_classes = 4
logits = torch.randn(batch_size, nb_classes, requires_grad=True)
targets = torch.randint(0, nb_classes, (batch_size,))
weights = torch.rand(nb_classes)

print(targets)
# tensor([2, 3, 3, 0, 1, 2, 3, 2, 0, 1])
print(weights)
# tensor([0.9253, 0.1432, 0.8336, 0.9465])

# built-in class weighting
weighted_criterion = nn.CrossEntropyLoss(weight=weights, reduction="mean")
loss = weighted_criterion(logits, targets)
print(loss)
# tensor(2.6470, grad_fn=<NllLossBackward0>)

# manual weighting of the unreduced loss for comparison
raw_criterion = nn.CrossEntropyLoss(reduction="none")
loss_raw = raw_criterion(logits, targets)
print(loss_raw)
# tensor([2.3437, 3.4518, 3.3348, 2.1393, 0.9009, 5.2935, 1.5514, 0.5305, 2.6735,
#         3.5571], grad_fn=<NllLossBackward0>)
loss_weighted = (loss_raw * weights[targets] / weights[targets].sum()).sum()
print(loss_weighted)
# tensor(2.6470, grad_fn=<SumBackward0>)

Indexing the weights tensor with the targets works fine and returns the expected loss as verified in my manual comparison.
However, in a multi-label classification use case each sample can belong to zero, one, or multiple classes as seen here:

# multi-label classification (reusing logits, batch_size, nb_classes, and weights from the snippet above)
targets = torch.randint(0, 2, (batch_size, nb_classes))
print(targets)
# tensor([[1, 0, 1, 1],
#         [1, 1, 1, 1],
#         [0, 0, 1, 1],
#         [1, 1, 1, 1],
#         [0, 1, 1, 0],
#         [1, 0, 0, 0],
#         [0, 0, 0, 1],
#         [1, 0, 0, 1],
#         [1, 0, 0, 1],
#         [1, 0, 0, 0]])

raw_criterion = nn.BCEWithLogitsLoss(reduction="none")
loss_raw = raw_criterion(logits, targets.float())
print(loss_raw)
# tensor([[0.7298, 0.3813, 1.4581, 0.5213],
#         [1.4387, 0.5741, 0.6803, 2.5363],
#         [0.3412, 0.3628, 0.1299, 1.4728],
#         [0.9664, 1.2701, 0.2159, 2.8560],
#         [0.6312, 0.4537, 0.6042, 0.3793],
#         [0.2912, 2.0887, 0.0754, 1.8689],
#         [1.0385, 2.0379, 0.8648, 0.3196],
#         [0.7231, 0.1918, 1.3323, 0.8123],
#         [0.8825, 0.9056, 1.9396, 0.3908],
#         [0.7225, 0.3262, 0.3110, 2.5515]],
#        grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

print(weights)
# tensor([0.9253, 0.1432, 0.8336, 0.9465])

It’s now unclear to me how you would like to apply the weights.
E.g. take the first sample with a target of [1, 0, 1, 1], which means classes 0, 2, and 3 are “active”.
Would you sum the corresponding weights and multiply it directly with the unreduced loss?

Yes, I would like to do something like 0.9253 * 0.7298 + 0.1432 * 0.3813 + 0.8336 * 1.4581 + 0.9465 * 0.5213 according to your example.
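
If that is the intended weighting (each class's loss term multiplied by its class weight, regardless of whether the target is 0 or 1), a self-contained sketch would be:

import torch
import torch.nn as nn

batch_size, nb_classes = 10, 4
logits = torch.randn(batch_size, nb_classes, requires_grad=True)
targets = torch.randint(0, 2, (batch_size, nb_classes)).float()
weights = torch.tensor([0.9253, 0.1432, 0.8336, 0.9465])   # per-class weights

# unreduced loss: shape [batch_size, nb_classes]
loss_raw = nn.BCEWithLogitsLoss(reduction="none")(logits, targets)

# weight every class column by its class weight, then reduce
loss_per_sample = (loss_raw * weights).sum(dim=1)   # first sample: w_0*l_0 + w_1*l_1 + w_2*l_2 + w_3*l_3
loss_weighted = loss_per_sample.mean()
loss_weighted.backward()

Whether to take the plain mean or to normalize by weights.sum() (similar to what nn.CrossEntropyLoss does with reduction="mean") is a design choice.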