Hello, I am new to PyTorch and am currently working on a text classification task using deep learning networks. The dataset contains two classes and is highly imbalanced (pos:neg == 100:1), so I want to try focal loss.
I have seen some focal loss implementations, but they are a little hard to follow, so I implemented focal loss (from "Focal Loss for Dense Object Detection") myself with pytorch==1.0 and python==3.6.5. It gives about the same results as the standard binary cross-entropy loss, and sometimes worse. Did I implement it correctly?
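For reference, the paper writes the binary focal loss in terms of p_t, the model's probability for the true class, where p is the predicted probability of class y = 1:

$$
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{otherwise} \end{cases}
$$

$$
\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log(p_t)
$$

With γ = 0 this reduces to α-balanced cross entropy, so any difference from BCE comes from the (1 − p_t)^γ factor, which down-weights well-classified examples. The paper reports α = 0.25, γ = 2 as working best in its detection experiments.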
Here is the code:
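import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """ binary focal loss """
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.weight = torch.Tensor([alpha, 1 - alpha])  # per-class weights: index 0 gets alpha
        self.nllLoss = nn.NLLLoss(weight=self.weight)
        self.gamma = gamma

    def forward(self, input, target):
        # input: raw scores of shape (N, 2); target: class indices of shape (N,)
        log_logits = F.log_softmax(input, dim=1)  # more stable than torch.log(F.softmax(...))
        softmax = log_logits.exp()
        fix_weights = (1 - softmax) ** self.gamma  # modulating factor (1 - p)^gamma
        logits = fix_weights * log_logits
        return self.nllLoss(logits, target)  # takes -logits[target], weighted mean over the batch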
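One way to cross-check this is a per-sample sketch that gathers p_t for the true class explicitly and applies α_t per sample. It is not the paper's reference code; passing `alpha` as a per-class weight tensor (default `None`, i.e. no α weighting) is an assumption of this sketch. One behavioral difference worth knowing: with `reduction='mean'`, `nn.NLLLoss(weight=...)` divides by the sum of the target-class weights rather than by N, whereas this sketch takes a plain mean over the batch. Also note that `torch.Tensor([alpha, 1 - alpha])` gives weight `alpha` to class index 0, so it is worth checking which index your rare class uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss on raw logits for class-index targets (sketch).

    Assumption: `alpha` is an optional per-class weight tensor, e.g.
    torch.tensor([w0, w1]); None means no alpha weighting.
    """

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        # Buffer so that .to(device) / .cuda() moves alpha along with the module
        self.register_buffer(
            "alpha", None if alpha is None else torch.as_tensor(alpha, dtype=torch.float)
        )

    def forward(self, logits, target):
        # log p for every class, computed stably from logits of shape (N, C)
        log_probs = F.log_softmax(logits, dim=1)
        # log p_t: log-probability of each sample's true class, shape (N,)
        log_pt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        # FL = -(1 - p_t)^gamma * log(p_t)
        loss = -((1 - pt) ** self.gamma) * log_pt
        if self.alpha is not None:
            # alpha_t: weight of each sample's true class
            loss = self.alpha[target] * loss
        # Plain mean over the batch (no renormalization by the alpha weights)
        return loss.mean()
```

With `gamma=0` and no `alpha`, the sketch reduces to `F.cross_entropy`, which makes a quick sanity check: if both versions give identical numbers there, any remaining difference in training comes from γ and α rather than from the plumbing.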