Implementing Focal Loss for a binary classification problem

I have been trying to implement Focal Loss for binary classification and have found several useful posts, but each solution differs slightly from the others. This is less a bug report than a request for review. My question comes in two parts:

  1. Is my implementation (inspired by others) correct? If not, please advise.
  2. Are there any improvements to consider (e.g., efficiency or generalizability)?

Here is the snippet (it assumes the inputs are already probabilities, i.e., the output of a sigmoid):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha  # weighting factor applied to every term
        self.gamma = gamma  # focusing parameter; down-weights easy examples

    def forward(self, inputs, targets):
        # Per-element BCE; inputs are probabilities, not logits
        BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # exp(-BCE) recovers p_t: p if target == 1, 1 - p otherwise
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return torch.mean(F_loss)
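As a sanity check on the `pt = torch.exp(-BCE_loss)` step: since the per-element BCE is `-log(p_t)`, exponentiating its negation recovers `p_t` exactly, i.e., `p` where the target is 1 and `1 - p` where it is 0, which matches the paper's definition. A small self-contained check (the tensors here are illustrative):

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.9, 0.2])   # predicted probabilities
t = torch.tensor([1.0, 0.0])   # binary targets

bce = F.binary_cross_entropy(p, t, reduction='none')  # -log(p_t)
pt = torch.exp(-bce)                                  # recovers p_t

# p_t should equal p where t == 1 and 1 - p where t == 0
expected = p * t + (1 - p) * (1 - t)
print(torch.allclose(pt, expected))  # True
```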

Many thanks in advance :smile:


I just came across the official implementation of focal loss in torchvision:
https://pytorch.org/vision/main/_modules/torchvision/ops/focal_loss.html
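Two differences from the snippet above are worth noting: the torchvision version operates on raw logits (it uses `binary_cross_entropy_with_logits` internally, which is more numerically stable), and it weights each term with a target-dependent `alpha_t = alpha * t + (1 - alpha) * (1 - t)` rather than a fixed `alpha`. A rough sketch of that logic (not the library code verbatim; the function name here is my own):

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss_sketch(logits, targets, alpha=0.25, gamma=2.0):
    """Sketch of torchvision-style focal loss on raw logits (mean reduction)."""
    p = torch.sigmoid(logits)
    # Stable BCE computed directly from logits
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    # p_t = p for positives, 1 - p for negatives
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce * (1 - p_t) ** gamma
    if alpha >= 0:
        # Target-dependent class weight, unlike the fixed alpha above
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean()

logits = torch.tensor([2.0, -1.0])
targets = torch.tensor([1.0, 0.0])
print(sigmoid_focal_loss_sketch(logits, targets))
```

For inputs that have already passed through a sigmoid, the fixed-alpha version above is still a valid focal loss; the `alpha_t` form is simply the weighting used in the original paper and in torchvision.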