I have recently been trying to implement Focal Loss for binary classification. I have found some useful posts here and there, but each solution differs slightly from the others. So this is less a bug report than a request for a second opinion. My question has two parts:
- Is my implementation (inspired by others) correct? If not, what should I change?
- Are there any improvements worth considering (e.g., efficiency or generalizability)?
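For reference, here is the binary focal loss as defined by Lin et al. (2017, "Focal Loss for Dense Object Detection"), written with p as the predicted probability of the positive class and y ∈ {0, 1} as the label. Note that in the paper the weight α_t depends on the label, and the paper's defaults are α = 0.25 and γ = 2:

```latex
p_t = \begin{cases} p, & y = 1 \\ 1 - p, & y = 0 \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha, & y = 1 \\ 1 - \alpha, & y = 0 \end{cases}
\\
\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log(p_t)
\\
\mathrm{BCE}(p, y) = -\log(p_t) \;\Rightarrow\; p_t = \exp\!\left(-\mathrm{BCE}(p, y)\right)
```

The last line is the identity that lets `pt = torch.exp(-BCE_loss)` recover p_t from the element-wise BCE.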
Here is the snippet (it assumes the input is the output of a sigmoid function):
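import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # inputs are probabilities (sigmoid outputs); targets are 0/1 labels
        BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # exp(-BCE) recovers p_t: p for positives, 1 - p for negatives
        pt = torch.exp(-BCE_loss)
        # scalar alpha, applied to every element regardless of its label
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return torch.mean(F_loss)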
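A possible variant touching on the second question, sketched under stated assumptions rather than offered as a definitive implementation. It takes raw logits (so the model would drop its final sigmoid) and uses `F.binary_cross_entropy_with_logits`, which PyTorch documents as more numerically stable than a separate sigmoid followed by BCE. It also uses the label-dependent α_t from the paper instead of a single scalar α. The class name `FocalLossWithLogits` and its `reduction` argument are hypothetical choices for illustration, not an existing PyTorch API, and targets are assumed to be 0/1 values with the same shape as the logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossWithLogits(nn.Module):
    """Binary focal loss on raw logits (hypothetical class, illustration only).

    Assumes `targets` holds 0/1 values with the same shape as `logits`.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction  # "mean", "sum", or anything else for "none"

    def forward(self, logits, targets):
        targets = targets.float()
        # Element-wise BCE computed directly from logits (more stable than sigmoid + BCE).
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        # p_t = exp(-BCE): p for positives, 1 - p for negatives.
        pt = torch.exp(-bce)
        # Label-dependent weight: alpha for positives, 1 - alpha for negatives.
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * (1 - pt) ** self.gamma * bce
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


if __name__ == "__main__":
    logits = torch.tensor([2.0, -1.0, 0.5])
    targets = torch.tensor([1.0, 0.0, 1.0])
    criterion = FocalLossWithLogits(alpha=0.25, gamma=2.0)
    print(criterion(logits, targets))
```

With γ = 0 and α = 0.5 this reduces to half of plain BCE, which is a quick sanity check. For comparison, torchvision ships `torchvision.ops.sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2, reduction="none")`, which also takes logits and applies α per label.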
Many thanks in advance