So I have been trying to implement Focal Loss recently (for binary classification), and I have found some useful posts here and there; however, each solution differs a little from the others. This is less a bug report than a consultation. My question comes in two parts:
- I would like advice on the correctness of my implementation, which is inspired by others (i.e., is it correct? If not, please advise).
- Are there any improvements to consider (e.g., efficiency or generalizability)?
Here is the snippet (assuming the input is the output of a sigmoid):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Per-element BCE (no reduction), so it can be re-weighted below
        BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # pt is the model's estimated probability of the true class
        pt = torch.exp(-BCE_loss)
        # Focal term: down-weight easy examples via (1 - pt)^gamma
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return torch.mean(F_loss)
```
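For context, this is how I am exercising it end to end (the class is repeated so the snippet runs standalone; the tensors are just random placeholders, not real data):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return torch.mean(F_loss)

# Placeholder logits and binary targets for a batch of 8 samples
logits = torch.randn(8)
inputs = torch.sigmoid(logits)          # probabilities in (0, 1), as the loss expects
targets = torch.randint(0, 2, (8,)).float()

criterion = FocalLoss(alpha=0.25, gamma=2)
loss = criterion(inputs, targets)       # scalar tensor, non-negative
```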
Many thanks in advance