Multi-Label, Multi-Class class imbalance

Hi, I have implemented a network for multi-label, multi-class classification using BCEWithLogitsLoss on 6 output units (one logit per class). However, I have a class imbalance and was wondering if there is a way to weight the classes in the multi-label sense.

I have labels in the following one-hot encoded format: [0,1,0,1,0,0], which means that class 1 and class 3 are present.

The classes have the weighting {0: 0.16 , 1: 0.16 , 2: 0.25 , 3: 0.25 , 4: 0.083, 5: 0.083}.

Any ideas would be great, thank you :).


I think the easiest approach would be to specify reduction='none' in your criterion and then multiply each element of the unreduced loss with your weights:

import torch
from torch import nn

# One sample with classes 1 and 3 present
target = torch.tensor([[0, 1, 0, 1, 0, 0]], dtype=torch.float32)
output = torch.randn(1, 6, requires_grad=True)  # raw logits
weights = torch.tensor([0.16, 0.16, 0.25, 0.25, 0.083, 0.083])

# Keep the per-element losses so each class can be weighted
criterion = nn.BCEWithLogitsLoss(reduction='none')
loss = criterion(output, target)
loss = (loss * weights).mean()

Would that work for you?


That is brilliant, thank you!

Aren't we accentuating the loss for the class that has a lot of examples and over-penalizing the class that has few examples?
Over-penalizing: it is already penalized by default because it has few examples, so the model cannot fit it well.

Generally yes, but as I understood from the original post, the target would be just one sample for a multi-label classification task (using 6 classes) with the given weights.


I don't get the idea of this post. Why do we have to use reduction='none' for a multi-label problem? To my knowledge, we have to multiply by weights derived from the class distribution of the dataset, but how do we get the weights? And what is the difference between multiplying the weights with the 'none' result and with the 'mean' result? Thanks!

You don't have to use reduction='none', but the question in this topic was how to multiply each class prediction with a certain weight, so this is a valid approach.

I’m not sure how @MultiSeedLoaf got the weights, so we might wait for his answer.
A simple approach would be to use the inverse of the class count, but I’m not sure if that’s the best approach for a multi-label classification.

If you use reduction='mean', you’ll get a single scalar, so that you cannot multiply the loss with the class weights anymore.
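As a rough sketch of the two points above, the snippet below derives inverse-count weights and applies them to the unreduced loss. The target matrix `train_targets` and the names `class_counts` and `class_weights` are made up for illustration, and inverse counts are only one possible weighting scheme, not necessarily the one used in the original question.

```python
import torch
from torch import nn

# Hypothetical multi-hot training targets: 4 samples, 6 classes (made-up data).
train_targets = torch.tensor([
    [0, 1, 0, 1, 0, 0],
    [1, 1, 0, 0, 0, 0],
    [0, 1, 1, 0, 0, 1],
    [1, 0, 0, 0, 1, 0],
], dtype=torch.float32)

# Count the positives per class and invert; clamp avoids a division by zero
# for a class that never occurs.
class_counts = train_targets.sum(dim=0)
class_weights = 1.0 / class_counts.clamp(min=1.0)

logits = torch.randn(4, 6, requires_grad=True)

# reduction='none' keeps one loss value per (sample, class) entry, so the
# [6] weight vector can still be broadcast over the class dimension.
unreduced = nn.BCEWithLogitsLoss(reduction='none')(logits, train_targets)
weighted_loss = (unreduced * class_weights).mean()

# reduction='mean' has already collapsed everything into a single scalar,
# so there is no class dimension left to weight.
mean_loss = nn.BCEWithLogitsLoss(reduction='mean')(logits, train_targets)
```

Here `unreduced` has shape [4, 6], while `mean_loss` is a 0-dimensional tensor.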


Hello, @ptrblck

I'm just wondering if I could use "pos_weight" in BCEWithLogitsLoss for an imbalanced dataset. Do pos_weight and the approach you provided have the same effect?

I have tested the three methods. It's strange that they all give the same result:

from torch import nn
import torch

class WeightedMultilabel(nn.Module):
    def __init__(self, weights: torch.Tensor):
        super().__init__()
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')
        self.weights = weights

    def forward(self, outputs, targets):
        # Weight each per-class loss before averaging
        loss = self.criterion(outputs, targets)
        return (loss * self.weights).mean()

# tensor(7.8804)
# tensor(7.8804)
# tensor(7.8804)

If you consider each label to be a node in the final layer of the network, "pos_weight" will help in tackling the class imbalance in each node individually (check the formula in the documentation). For example, in a multi-label binary classification with 4 labels, it will help in assigning weights to the positive class for each label individually. The weighting scheme mentioned here will instead assign an importance to each of the individual losses calculated for the 4 labels. This will not necessarily tackle the class imbalance problem, but I think it should scale the gradients coming from a node.

What does reduction actually mean?

Reduction in this context means how to reduce a tensor with multiple elements to a scalar or generally how to reduce a dimension with a size > 1 to a size of 1.
Loss functions accept the reduction argument, which specifies whether the mean, the sum, etc. should be used to yield the scalar loss value given a batch of model outputs and targets.
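A small sketch of the three reduction modes of BCEWithLogitsLoss may make this concrete; the logits and targets are arbitrary illustrative values.

```python
import torch
from torch import nn

# Arbitrary example: 2 samples, 3 classes.
logits = torch.tensor([[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]])
targets = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])

per_element = nn.BCEWithLogitsLoss(reduction='none')(logits, targets)  # shape [2, 3]
summed = nn.BCEWithLogitsLoss(reduction='sum')(logits, targets)        # scalar
averaged = nn.BCEWithLogitsLoss(reduction='mean')(logits, targets)     # scalar

# 'sum' and 'mean' are simply reductions of the 'none' result.
assert torch.allclose(summed, per_element.sum())
assert torch.allclose(averaged, per_element.mean())
```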

Can this approach be helpful in a multi-label, multiclass class imbalance image segmentation also?

Should I use the weight or the pos_weight parameter of nn.BCEWithLogitsLoss, given that it will be combined with a weighted dice loss function?

My approach to getting the class weights is to sum the 1's in each class label over the targets of the training dataset and take the inverse, as suggested.

Yes, I would think so.

Multiplying by the weights as in my code snippet applies the corresponding weight to each element of the unreduced loss before the reduction is done.
The pos_weight acts directly on the positive term of the loss function as shown in the docs, so they wouldn’t be equivalent.
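To illustrate the difference, here is a minimal sketch that compares both variants on the same inputs. The values in `w` are arbitrary, and the `manual` line writes out the pos_weight formula from the docs using `F.logsigmoid`.

```python
import torch
import torch.nn.functional as F
from torch import nn

logits = torch.tensor([[0.3, -1.2, 0.8, 2.0]])
targets = torch.tensor([[1.0, 0.0, 0.0, 1.0]])
w = torch.tensor([0.67, 0.82, 4.0, 9.0])  # arbitrary illustrative values

# pos_weight scales only the positive (target == 1) term of each element.
loss_pos_weight = nn.BCEWithLogitsLoss(pos_weight=w, reduction='none')(logits, targets)

# The same formula written out by hand: log(1 - sigmoid(x)) == logsigmoid(-x).
manual = -(w * targets * F.logsigmoid(logits)
           + (1 - targets) * F.logsigmoid(-logits))

# Multiplying the unreduced loss scales the positive AND the negative term.
loss_scaled = nn.BCEWithLogitsLoss(reduction='none')(logits, targets) * w
```

The two results agree for entries whose target is 1 and differ for entries whose target is 0, since only the multiplied version rescales the negative term.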

Thanks @ptrblck

Does this approach also apply to the multi-class, multi-label case?

If you are using these calculated weights, I would start with the proposed calculation from the docs:

For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to 300/100 = 3. The loss would act as if the dataset contains 3 × 100 = 300 positive examples.

On the other hand, if you want to weight each sample in the batch, I would normalize the weights to a sum of 1 so that each batch loss has the same mean.

The proposed weight calculation is meant for pos_weight, but in case of the other weight calculations suggested in this thread, should I pass them as pos_weight or as weight in BCEWithLogitsLoss?

Just wanted to say that overall your comments are simple and right to the point … Tnx!!

Hi @ptrblck

Actually, I have not completely figured out the answers to moreshud's question.

Imagine that I have a multi-class, multi-label classification problem; my imbalanced one-hot coded dataset includes 1000 images with 4 labels with the following frequencies: class 0: 600, class 1: 550, class 2: 200, class 3: 100.
As I said, the targets are in a one-hot coded structure.
For instance, the target [0, 1, 1, 0] means that classes 1 and 2 are present in the corresponding image.

Well, in order to calculate the BCEWithLogitsLoss while accounting for the data imbalance, one way is the one you suggested in this post: Multi-Label, Multi-Class class imbalance - #2 by ptrblck. Here, we calculate the class weights by inverting the frequency of each class, i.e., the class weight tensor in my example would be torch.tensor([1/600, 1/550, 1/200, 1/100]). After that, the unreduced loss is multiplied by the class weight tensor, and the final loss is the mean of the resulting tensor.

However, as far as I know, the pos_weight parameter of BCEWithLogitsLoss could also be used in this case. Here is my question:
To my knowledge, the two tensors (the class weight tensor in the previous paragraph and the pos_weight tensor) are totally different. For each class, the number of positive and negative samples should be counted, and num_negative / num_positive would be the pos_weight for that particular class. In my example, the pos_weight of class 0 would be (1000-600)/600 = 0.67, the pos_weight of class 1 would be (1000-550)/550 = 0.82, the pos_weight of class 2 would be (1000-200)/200 = 4, and the pos_weight of class 3 would be (1000-100)/100 = 9. Thus, the pos_weight tensor would be torch.tensor([0.67, 0.82, 4, 9]). Is this the right way to calculate the pos_weight tensor? If the answer is yes, I think that the previous method (calculating the class weights and multiplying them with the unreduced loss) would be more convenient for a dataset with a large number of labels, since we only need to invert the frequencies. Am I right?
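For reference, both tensors from this example can be computed from the per-class positive counts in a few lines. This is only a sketch: `num_samples`, `pos_counts` and the random logits and targets are illustrative names and data, and it shows the computation rather than settling which weighting is preferable.

```python
import torch
from torch import nn

num_samples = 1000
pos_counts = torch.tensor([600.0, 550.0, 200.0, 100.0])  # positives per class

# pos_weight as described in the docs: negatives / positives for each class.
neg_counts = num_samples - pos_counts
pos_weight = neg_counts / pos_counts  # roughly [0.67, 0.82, 4.0, 9.0]

# The inverse-frequency class weights from the other approach.
inverse_freq = 1.0 / pos_counts

# pos_weight is passed to the criterion and broadcast over the class dimension.
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.randn(8, 4)                      # made-up batch of 8 samples
targets = torch.randint(0, 2, (8, 4)).float()   # made-up multi-hot targets
loss = criterion(logits, targets)
```

Both computations are a single vectorized operation on the count tensor, so neither becomes noticeably harder as the number of labels grows.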

Another question is about the weight parameter of BCEWithLogitsLoss. As shown in the formula, the weight parameter is a tensor that is multiplied by the whole loss, not merely by the positive-target term (as opposed to pos_weight). How is the weight parameter tensor different from the class weights tensor, given that the class weights tensor is likewise multiplied by the whole loss? Also, the docs say that the weight parameter tensor is of size nbatch, and I do not understand what its function is.
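As far as I can tell, the weight tensor is simply broadcast against the unreduced loss, so a tensor of shape [num_classes] acts as per-class weights and a tensor of shape [batch, 1] acts as per-sample weights. A minimal sketch under that assumption, with made-up inputs and an illustrative `sample_weights` tensor:

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(5, 4)                      # batch of 5 samples, 4 classes
targets = torch.randint(0, 2, (5, 4)).float()
class_weights = torch.tensor([1 / 600, 1 / 550, 1 / 200, 1 / 100])

# Manual approach: weight the unreduced loss, then average.
manual = (nn.BCEWithLogitsLoss(reduction='none')(logits, targets) * class_weights).mean()

# Built-in weight argument with the same [4] tensor, broadcast over the batch.
built_in = nn.BCEWithLogitsLoss(weight=class_weights)(logits, targets)

# Per-sample weights need shape [batch, 1] to broadcast over the class dimension.
sample_weights = torch.tensor([[1.0], [2.0], [1.0], [0.5], [1.0]])
per_sample = nn.BCEWithLogitsLoss(weight=sample_weights)(logits, targets)
```

With these shapes, `manual` and `built_in` give the same value, which may also explain why different weighting variants can end up with identical results.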

I deeply appreciate your consideration.

I have a similar question. Have you found the answer?