Loss for imbalanced multi-label classification


I am using multi-hot labels for a multi-label classification problem. Initially I used BCEWithLogitsLoss, but because the dataset is quite imbalanced, the model soon predicts all 0s. I then tried the focal loss below, but the model does not converge. Are there any suggestions?

import torch

def focal_loss(self, pred, gt):
    ''' Modified focal loss. Exactly the same as CornerNet.
      Runs faster and costs a little bit more memory
    '''
    # pred must hold probabilities in (0, 1), i.e. sigmoid outputs, not raw logits
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_inds

    num_pos  = pos_inds.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    # Normalize by the number of positives, unless there are none
    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos

    return loss

Hi Zhiyuan!

Consider using BCEWithLogitsLoss with its pos_weight (or possibly
its weight) constructor-argument to account for the data imbalance.
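
Here is a minimal sketch of what that could look like. The toy targets and the negative-to-positive ratio used for pos_weight are assumptions for illustration. That ratio is a common heuristic, not the only reasonable choice, and in practice you would compute the counts over your whole training set.

```python
import torch
import torch.nn as nn

# Hypothetical toy data: 8 samples, 3 classes, multi-hot targets.
targets = torch.tensor([
    [1., 0., 0.],
    [0., 0., 0.],
    [1., 0., 0.],
    [0., 1., 0.],
    [1., 0., 0.],
    [0., 0., 0.],
    [1., 0., 1.],
    [0., 0., 0.],
])

# Per-class counts of positive and negative labels.
num_pos = targets.sum(dim=0)
num_neg = targets.shape[0] - num_pos

# Assumed heuristic: weight each class's positives by its
# negative/positive ratio. clamp() avoids division by zero
# for a class that has no positive samples.
pos_weight = num_neg / num_pos.clamp(min=1.0)

# pos_weight has one entry per class and broadcasts over the batch.
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Stand-in for raw model outputs (logits, no sigmoid applied).
logits = torch.zeros(8, 3, requires_grad=True)
loss = criterion(logits, targets)
loss.backward()
```

Note that BCEWithLogitsLoss applies the sigmoid internally, so the model should output raw logits. If the weighted loss now over-predicts positives, a smaller pos_weight (for example, the square root of the ratio) may be worth trying.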


K. Frank