Dealing with imbalanced datasets in pytorch

Sounds reasonable. You could probably also try sum(nr_samples_per_label).
I’m not sure if there is a general rule of thumb as you might want to balance your per-class accuracies manually.


Shouldn’t the weights be like:

weight_label_i =  nr_samples_of_label_i  / total_number_of_samples

i = 1, 2, ...


This would weight the majority classes higher, while we would like to weight the loss of the minority classes higher, or am I mistaken?

You are right. The BCELoss is given by:
$$\ell(x, y) = L = \{l_1, \dots, l_N\}^\top, \qquad l_n = -w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]$$

In this case, we invert the above:
weight_label_i = total_number_of_samples / nr_samples_of_label_i
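A minimal sketch of computing these inverse-frequency weights from a tensor of integer labels (the `labels` tensor is made up for illustration). For a multi-class criterion such as `nn.CrossEntropyLoss` or `nn.NLLLoss`, a vector with one entry per class is what the `weight` argument expects:

```python
import torch
import torch.nn as nn

# Hypothetical integer class labels for a small dataset.
labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 2, 0])

counts = torch.bincount(labels).float()   # nr_samples_of_label_i -> [7., 2., 1.]
total = labels.numel()                    # total_number_of_samples -> 10
class_weights = total / counts            # weight_label_i -> approx. [1.43, 5.0, 10.0]

# A per-class weight vector like this fits multi-class criteria.
criterion = nn.CrossEntropyLoss(weight=class_weights)
logits = torch.randn(4, 3)                # [batch, num_classes]
loss = criterion(logits, labels[:4])
```

Note that this per-class vector does not map directly onto the per-element `weight` of BCELoss, which is the point raised next.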

The documentation also says:

weight (Tensor, optional) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size “nbatch”

So the weight vector works on the batch level and not on the per-class level, or am I missing something here?

Yeah, you are right! Thanks for pointing this out. I was trying to keep the two new threads about weighting apart and thought we were dealing with a classification loss like nn.NLLLoss.

@Skinish have a look at this thread to see how to apply pos_weight instead.


I am not sure pos_weight will do it. Even though BCEWithLogitsLoss has this pos_weight argument, it does not seem to fulfill what is needed (I could be mistaken). The simple way I see to resolve this issue is to multiply the weight vector element-wise with the output and with the target/label vector when computing the training loss, as follows:

criterion = nn.BCEWithLogitsLoss()
loss = criterion(W*output.float(), W*target.float())

For more flexibility and ease, the (balance) weight vector could be generated within the dataset class.
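One way this could look, as a sketch: a hypothetical `WeightedDataset` whose `__getitem__` returns the sample weight alongside the input and target, so the DataLoader batches the weights automatically. The inverse-class-frequency rule used here is an assumption; any per-sample weight could be stored instead:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class WeightedDataset(Dataset):
    """Hypothetical dataset returning (input, label, weight) triples.

    The per-sample weight is looked up in a per-class table computed once
    from the integer labels (assumed rule: inverse class frequency).
    """

    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels
        counts = torch.bincount(labels).float()
        self.class_weights = labels.numel() / counts

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        return self.inputs[idx], label, self.class_weights[label]


inputs = torch.randn(6, 3)
labels = torch.tensor([0, 0, 0, 0, 1, 1])
loader = DataLoader(WeightedDataset(inputs, labels), batch_size=3)
x, y, w = next(iter(loader))   # w holds one weight per sample in the batch
```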

I tried to implement this weight multiplication. I'm not sure it is the best way to do it, as I had to use torch.transpose twice. Here it goes:

# 10 is the batch size, so each sample has a weight value

torch.Size([10, 644])
tensor([ 0.9987,  0.9997,  0.9997,  0.9992,  0.9997,  0.9985,  0.9905,
         0.9911,  0.9476,  0.9944], dtype=torch.float64, device='cuda:0')

ipdb> output = torch.mul(weight, torch.transpose(output, 0, 1) )
ipdb> output = torch.transpose(output, 0,1)
torch.Size([644, 10])

NB. I have added the weights to the Dataset class

The whole thing will look like this:

if cf.use_weight_to_balance_data:
       weight =
       output = torch.mul(weight, torch.transpose(output.double(), 0, 1) )
       output = torch.transpose(output, 0, 1)
       target = torch.mul(weight, torch.transpose(target.double(), 0, 1) )
       target = torch.transpose(target, 0, 1)
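As a side note on the double `torch.transpose`: broadcasting can do the same per-sample scaling directly. A small check on random tensors (shapes taken from the debug output above, values made up):

```python
import torch

batch_size, num_outputs = 10, 644
output = torch.randn(batch_size, num_outputs)
weight = torch.rand(batch_size)   # one weight per sample, shape [B]

# Version from the post: transpose to [C, B], multiply, transpose back.
via_transpose = torch.transpose(torch.mul(weight, torch.transpose(output, 0, 1)), 0, 1)

# Broadcasting version: view the weights as a column of shape [B, 1],
# so each row of `output` is scaled by its own sample weight.
via_broadcast = weight.unsqueeze(1) * output
```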

@Deeply @ptrblck thank you for all your help. Is the weight multiplication the viable solution then? I did not quite understand what would be wrong with passing the weight vector to the loss function (via the pos_weight argument), although the weight formula I proposed should be changed.

What would be the effect of passing a vector to weight then?

@Deeply I’m not sure it’s a good idea to multiply the output and target directly with the weight.
Both will pass through a criterion and maybe a sigmoid, so I would rather multiply the loss.
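A quick way to see why scaling the logits and targets is not the same as scaling the loss: give every sample weight 0. Weighting the loss then ignores those samples, while weighting the inputs turns every logit and target into 0, which still costs log 2 per element. A sketch with made-up tensors; the helper names are hypothetical:

```python
import torch
import torch.nn as nn

def weight_inputs(output, target, weight):
    # Approach from the post: scale logits and targets before the criterion.
    W = weight.unsqueeze(1)
    return nn.BCEWithLogitsLoss()(W * output, W * target)

def weight_loss(output, target, weight):
    # Alternative: scale the unreduced loss, then average.
    W = weight.unsqueeze(1)
    per_element = nn.BCEWithLogitsLoss(reduction="none")(output, target)
    return (W * per_element).mean()

torch.manual_seed(0)
output = torch.randn(4, 3)                     # logits, shape [B, C]
target = torch.randint(0, 2, (4, 3)).float()   # binary targets
zero = torch.zeros(4)                          # weight 0 for every sample

print(weight_inputs(output, target, zero).item())  # log(2) ≈ 0.6931, not 0
print(weight_loss(output, target, zero).item())    # 0.0
```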

@Skinish here is a GitHub issue discussing the introduction of pos_weight and a comparison between weight and pos_weight.
As you have different criteria for your targets and predictions, I think using pos_weight would work.
Let me know if that works for you.
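To illustrate the difference between the two arguments of `nn.BCEWithLogitsLoss`, a sketch with random logits and a made-up per-class factor: `weight` rescales every element's loss, while `pos_weight` rescales only the terms whose target is 1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)                     # shape [B, C]
target = torch.randint(0, 2, (4, 3)).float()
per_class = torch.tensor([1.0, 2.0, 3.0])      # hypothetical per-class factors

base = nn.BCEWithLogitsLoss(reduction="none")(logits, target)

# `weight` multiplies the whole loss term of each element (targets 0 and 1 alike).
w_loss = nn.BCEWithLogitsLoss(weight=per_class, reduction="none")(logits, target)

# `pos_weight` multiplies only the positive (y = 1) part of each class's loss.
p_loss = nn.BCEWithLogitsLoss(pos_weight=per_class, reduction="none")(logits, target)

neg = target == 0
pos = target == 1
```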


Hi guys, recently I played a lot with:

  • Weighted Semantic segmentation
  • Imbalanced data (Google Open Images)

What worked for me:

  • Loss / mask weighting - showed a lot of improvement. Below is my Loss, and here is the result description
import torch
import torch.nn as nn
import torch.nn.functional as F

class SemsegLossWeighted(nn.Module):
    # NOTE: the constructor signature was cut off in the post; the parameter
    # names come from the body below, the default values are assumptions.
    def __init__(self,
                 use_running_mean=False,
                 bce_weight=1.0,
                 dice_weight=1.0,
                 eps=1e-10,
                 gamma=0.9,
                 use_weight_mask=False,
                 deduct_intersection=False):
        super().__init__()
        self.use_weight_mask = use_weight_mask
        self.nll_loss = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.eps = eps
        self.gamma = gamma  # decay of the running loss averages
        self.use_running_mean = use_running_mean
        self.deduct_intersection = deduct_intersection
        if self.use_running_mean:
            self.register_buffer('running_bce_loss', torch.zeros(1))
            self.register_buffer('running_dice_loss', torch.zeros(1))

    def reset_parameters(self):
        # only valid when use_running_mean=True (the buffers exist only then)
        self.running_bce_loss.zero_()
        self.running_dice_loss.zero_()

    def forward(self, outputs, targets, weights):
        # outputs (logits) and targets are assumed to be BxCxWxH
        assert len(outputs.shape) == len(targets.shape)
        # assert that B, W and H are the same
        assert outputs.size(0) == targets.size(0)
        assert outputs.size(2) == targets.size(2)
        assert outputs.size(3) == targets.size(3)
        # weights (the mask) are assumed to have the same BxCxWxH shape as outputs
        assert outputs.size(0) == weights.size(0)
        assert outputs.size(1) == weights.size(1)
        assert outputs.size(2) == weights.size(2)
        assert outputs.size(3) == weights.size(3)
        if self.use_weight_mask:
            # per-pixel weighted BCE
            bce_loss = F.binary_cross_entropy_with_logits(input=outputs,
                                                          target=targets,
                                                          weight=weights)
        else:
            bce_loss = self.nll_loss(input=outputs,
                                     target=targets)

        # soft Dice loss on the sigmoid probabilities
        dice_target = (targets == 1).float()
        dice_output = torch.sigmoid(outputs)
        intersection = (dice_output * dice_target).sum()
        if self.deduct_intersection:
            union = dice_output.sum() + dice_target.sum() - intersection + self.eps
        else:
            union = dice_output.sum() + dice_target.sum() + self.eps
        dice_loss = -torch.log(2 * intersection / union)

        if not self.use_running_mean:
            # fixed mixing weights
            bmw = self.bce_weight
            dmw = self.dice_weight
            # loss += torch.clamp(1 - torch.log(2 * intersection / union),0,100)  * self.dice_weight
        else:
            # exponential running averages of both losses (no gradient through them)
            self.running_bce_loss = self.running_bce_loss * self.gamma + bce_loss.detach() * (1 - self.gamma)
            self.running_dice_loss = self.running_dice_loss * self.gamma + dice_loss.detach() * (1 - self.gamma)

            bm = float(self.running_bce_loss)
            dm = float(self.running_dice_loss)

            # the loss with the larger running mean gets the smaller weight
            bmw = 1 - bm / (bm + dm)
            dmw = 1 - dm / (bm + dm)
        loss = bce_loss * bmw + dice_loss * dmw
        return loss, bce_loss, dice_loss

  • Over- / undersampling - worked technically, but gave no accuracy boost
  • Analyzing the internal structure of data and building a cascade of models
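For the sampling route, PyTorch ships `torch.utils.data.WeightedRandomSampler`. A minimal sketch on a made-up 90/10 dataset, drawing each sample with probability proportional to the inverse size of its class (one common choice, not necessarily what was used above):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical imbalanced dataset: 90 samples of class 0, 10 of class 1.
labels = torch.cat([torch.zeros(90, dtype=torch.long), torch.ones(10, dtype=torch.long)])
inputs = torch.randn(100, 3)
dataset = TensorDataset(inputs, labels)

# Per-sample weight = 1 / (size of the sample's class), so both classes
# are expected to be drawn about equally often.
class_counts = torch.bincount(labels).double()
sample_weights = (1.0 / class_counts)[labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

torch.manual_seed(0)
drawn = torch.cat([y for _, y in loader])   # labels of one "epoch" of draws
print(torch.bincount(drawn, minlength=2))   # roughly balanced counts
```

Because sampling is with replacement, minority samples are repeated, which is the oversampling effect mentioned above.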

Hope this is helpful.


You are absolutely correct!
In that case, it is better to use BCELoss instead of BCEWithLogitsLoss, in which case we need to apply the sigmoid to the output before multiplying it by the weight.
Or, if we use BCEWithLogitsLoss(reduction='none'), multiplying the unreduced loss by the weight and then taking the mean will do, as follows:

if cf.use_weight_to_balance_data:
       weight =
       loss = criterion(output, target)
       loss = torch.mul(weight, torch.transpose(loss.double(), 0, 1) )               
       loss= torch.mean(loss)
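A self-contained version of the snippet above, with random tensors standing in for the real batch (shapes taken from earlier in the thread, weights made up). Broadcasting a `[B, 1]` view of the weights replaces the transpose:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_outputs = 10, 644
output = torch.randn(batch_size, num_outputs)                    # logits
target = torch.randint(0, 2, (batch_size, num_outputs)).float()  # multi-label targets
weight = torch.rand(batch_size)                                  # hypothetical per-sample weights

criterion = nn.BCEWithLogitsLoss(reduction="none")
loss = criterion(output, target)              # shape [10, 644], one term per element
loss = (weight.unsqueeze(1) * loss).mean()    # scale each row by its sample weight, then average
```

Taking the plain mean keeps the behavior of the original snippet; dividing the weighted sum by `weight.sum() * num_outputs` instead would give a weighted average whose scale does not depend on the magnitude of the weights. Which one is preferable is a judgment call.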

As for pos_weight, the documentation says this:
where p_n is the positive weight of class n. p_n > 1 increases the recall, p_n < 1 increases the precision.

For example, if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to 300/100 = 3. The loss would act as if the dataset contains 3 × 100 = 300 positive examples.
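Following that rule, a per-class `pos_weight` can be computed as negatives divided by positives for each column of a multi-label target matrix. A sketch with a tiny made-up target tensor; the `clamp` that guards against classes without positives is an added assumption:

```python
import torch
import torch.nn as nn

# Hypothetical multi-label targets, shape [num_samples, num_classes].
targets = torch.tensor([
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1],
    [0, 0, 1],
]).float()

num_pos = targets.sum(dim=0)                   # positives per class -> [1, 1, 3]
num_neg = targets.shape[0] - num_pos           # negatives per class -> [3, 3, 1]
pos_weight = num_neg / num_pos.clamp(min=1)    # -> approx. [3.0, 3.0, 0.333]

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(torch.zeros_like(targets), targets)
```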

So I don’t think it would help in balancing the data. As an example of using pos_weight, I imagine a classifier for m diseases: each disease has statistics on its positive and negative cases, which can be used to estimate its pos_weight.

My latest method above runs fine, but there was no improvement, so in my case the balancing may have no effect. However, when I try to use:

criterion = nn.BCEWithLogitsLoss()
loss = criterion( output, target, weight=weight )

I got an error saying: got an unexpected value weight
So I moved the weight into the class constructor (which meant creating the criterion inside the batch loop), something like:

criterion = nn.BCEWithLogitsLoss(weight=weight)
loss = criterion( output, target )

which also gave an error saying:
RuntimeError: The size of tensor a (644) must match the size of tensor b (10) at non-singleton dimension 1

So I am not sure whether F.binary_cross_entropy_with_logits differs from nn.BCEWithLogitsLoss and whether that’s why my code is not running?!

As far as I know, both of these are mostly the same; the difference is in how weight is passed (per call for the function, in the constructor for the module).
As far as I can see, the docs say that the weight will be broadcast, but in my case either of these approaches worked with F.binary_cross_entropy_with_logits (if I remember correctly):

  • Make your weights be WxH
  • Or make your weights be BxCxWxH
  • If you try BxWxH or CxWxH - I guess there will be an error
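A sketch of why the shape matters, using random tensors with the shapes from the error above: the weight is broadcast against the loss starting from the trailing dimension, so a `[10]` vector lines up with the 644 outputs and fails, while a `[10, 1]` view broadcasts over them. It also shows that the functional and module forms agree once the weight has a broadcastable shape; the functional form simply takes the weight per call, which avoids rebuilding the module inside the batch loop:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
output = torch.randn(10, 644)                     # [B, C] logits
target = torch.randint(0, 2, (10, 644)).float()
weight = torch.rand(10)                           # one weight per sample, shape [B]

# Shape [10] is aligned with the last dimension (644), so this raises a
# size-mismatch RuntimeError like the one quoted above.
failed_with_1d_weight = False
try:
    F.binary_cross_entropy_with_logits(output, target, weight=weight)
except RuntimeError as err:
    failed_with_1d_weight = True
    print("shape [B] fails:", err)

# Reshaping to [B, 1] makes the weight broadcastable to [B, C].
w = weight.view(-1, 1)
functional = F.binary_cross_entropy_with_logits(output, target, weight=w)

# The module form takes the same tensor, but in its constructor.
module = nn.BCEWithLogitsLoss(weight=w)(output, target)
```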

Thank you for your feedback. Could you please explain further what kind of loss weighting you did here? By that I mean, what weights did you use? And what is the main difference between F.binary_cross_entropy_with_logits with a weight argument and nn.BCEWithLogitsLoss with a weight / pos_weight argument?

For what I see, by applying pos_weight in BCEWithLogitsLoss loss, the total loss is indeed getting higher, which is what was intended, but the results are the same, or even worse actually. Maybe the loss becomes harder to minimize?

If the overall loss increases, I would try to lower the learning rate to help the model converge.

I am using Adam so it should not make much difference :confused:

Probably, but it’s still worth a try :wink:


I am trying to deal with imbalanced data. Based on what I read in the discussion above, I compute each class’s frequency as:
nr_samples_of_label_i / total_number_of_samples

For instance, out of 250000 samples, one of the imbalanced classes contains 150000 samples:
150000 / 250000 = 0.6
One of the underrepresented classes:
20000/250000 = 0.08

So, to reduce the impact of the overrepresented class, I multiply its loss by 1 - 0.6 = 0.4.
To increase the impact of the underrepresented class, I multiply its loss by 1 - 0.08 = 0.92.
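For comparison, the arithmetic for both weighting schemes with the counts above. Since a common factor on all weights mainly rescales the loss, the ratio between the class weights is the more telling number: 1 - frequency gives the minority class 2.3 times the weight of the majority class, while the inverse-frequency rule from earlier in the thread gives 7.5 times, which equals the raw imbalance 150000 / 20000:

```python
# Class counts from the example above (250000 samples in total).
total = 250_000
counts = {"overrepresented": 150_000, "underrepresented": 20_000}

one_minus_freq = {k: 1 - n / total for k, n in counts.items()}   # scheme from the post
inverse_freq = {k: total / n for k, n in counts.items()}         # total / nr_samples_of_label_i

ratio_post = one_minus_freq["underrepresented"] / one_minus_freq["overrepresented"]
ratio_inv = inverse_freq["underrepresented"] / inverse_freq["overrepresented"]
print(round(ratio_post, 2), round(ratio_inv, 2))  # 2.3 7.5
```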

Is that an acceptable way of working?