Dealing with imbalanced datasets in PyTorch

Hi guys, recently I played a lot with:

  • Weighted semantic segmentation
  • Imbalanced data (Google Open Images)

What worked for me:

  • Loss / mask weighting - this showed the biggest improvement. Below is my loss module (with a short usage sketch after it), and here is the result description
import torch
import torch.nn as nn
import torch.nn.functional as F

class SemsegLossWeighted(nn.Module):
    def __init__(self,
                 use_running_mean=False,
                 bce_weight=1,
                 dice_weight=1,
                 eps=1e-10,
                 gamma=0.9,
                 use_weight_mask=False,
                 deduct_intersection=False
                 ):
        super().__init__()

        self.use_weight_mask = use_weight_mask
        
        # pixel-wise BCE on logits; the weighted variant is applied directly in forward()
        # when use_weight_mask is set
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.eps = eps
        # gamma is the decay factor for the running means of the two loss terms
        self.gamma = gamma

        self.use_running_mean = use_running_mean
        self.deduct_intersection = deduct_intersection

        if self.use_running_mean:
            self.register_buffer('running_bce_loss', torch.zeros(1))
            self.register_buffer('running_dice_loss', torch.zeros(1))
            self.reset_parameters()

    def reset_parameters(self):
        self.running_bce_loss.zero_()        
        self.running_dice_loss.zero_()            

    def forward(self,
                outputs,
                targets,
                weights):
        # outputs and targets are assumed to be BxCxWxH
        assert len(outputs.shape) == len(targets.shape)
        # assert that B, W and H are the same
        assert outputs.size(0) == targets.size(0)
        assert outputs.size(2) == targets.size(2)
        assert outputs.size(3) == targets.size(3)
        
        # weights are assumed to be BxCxWxH as well
        # assert that B, C, W and H are the same for outputs and weights
        assert outputs.size(0) == weights.size(0)
        assert outputs.size(1) == weights.size(1)
        assert outputs.size(2) == weights.size(2)
        assert outputs.size(3) == weights.size(3)
        
        if self.use_weight_mask:
            bce_loss = F.binary_cross_entropy_with_logits(input=outputs,
                                                          target=targets,
                                                          weight=weights)            
        else:
            bce_loss = self.bce_loss(input=outputs,
                                     target=targets)

        dice_target = (targets == 1).float()
        dice_output = torch.sigmoid(outputs)

        intersection = (dice_output * dice_target).sum()
        if self.deduct_intersection:
            union = dice_output.sum() + dice_target.sum() - intersection + self.eps
        else:
            union = dice_output.sum() + dice_target.sum() + self.eps

        # negative log of the (soft) dice coefficient; eps keeps the log finite for empty masks
        dice_loss = -torch.log((2 * intersection + self.eps) / union)
        
        if not self.use_running_mean:
            bmw = self.bce_weight
            dmw = self.dice_weight
        else:
            # exponential running means of each loss term, used to balance the two adaptively
            self.running_bce_loss = self.running_bce_loss * self.gamma + bce_loss.detach() * (1 - self.gamma)
            self.running_dice_loss = self.running_dice_loss * self.gamma + dice_loss.detach() * (1 - self.gamma)

            bm = float(self.running_bce_loss)
            dm = float(self.running_dice_loss)

            # each term is weighted by the relative magnitude of the other term's running mean
            bmw = 1 - bm / (bm + dm)
            dmw = 1 - dm / (bm + dm)

        loss = bce_loss * bmw + dice_loss * dmw

        return loss, bce_loss, dice_loss

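For reference, here is a minimal usage sketch of the class above (the shapes and the 10x foreground weighting are just illustrative assumptions, not values from my experiments):

criterion = SemsegLossWeighted(use_weight_mask=True, use_running_mean=True)

# hypothetical batch: logits and binary targets are BxCxWxH
logits = torch.randn(4, 1, 256, 256, requires_grad=True)
targets = (torch.rand(4, 1, 256, 256) > 0.95).float()

# per-pixel weight mask (also BxCxWxH): up-weight the rare foreground pixels
weights = torch.ones_like(targets)
weights[targets == 1] = 10.0

loss, bce_loss, dice_loss = criterion(logits, targets, weights)
loss.backward()
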
  • Over- / under-sampling (link) - worked technically, but gave no accuracy boost (see the sketch after this list)
  • Analyzing the internal structure of the data and building a cascade of models
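
For the sampling point, one standard way to do it in PyTorch is WeightedRandomSampler; a minimal sketch (the labels, class counts and dataset are hypothetical placeholders, not my actual setup):

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# hypothetical: labels[i] is the (dominant) class id of sample i
labels = torch.randint(0, 10, (10000,))
class_counts = torch.bincount(labels).float()

# rare classes get proportionally higher sampling probability
sample_weights = 1.0 / class_counts[labels]
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)

# loader = DataLoader(my_dataset, batch_size=32, sampler=sampler)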

Hope this is helpful.
