Is this a correct implementation for focal loss in pytorch?

Try this:

```python
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
pt = torch.exp(-BCE_loss)  # pt is the probability of the true class; avoids log(0) = -inf
F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
return F_loss.mean()
```

Remember the alpha term to address class imbalance, and keep in mind that this only works for binary classification.
This is very similar to your implementation; it just uses `binary_cross_entropy_with_logits`, which applies the sigmoid and those `.mul()` calls for you, and it also avoids the NaN problem present in your implementation when the probability is 0, i.e. log(0) = -inf.
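To see why `torch.exp(-BCE_loss)` recovers pt, here is a small sketch (the tensor values are made up for illustration): since BCE is -log(p) for target 1 and -log(1-p) for target 0, exponentiating its negation gives back the model's probability for the true class.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])

bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
pt = torch.exp(-bce)  # probability assigned to the true class

# Manual check: p = sigmoid(logit); pt = p if target == 1 else 1 - p
p = torch.sigmoid(logits)
pt_manual = torch.where(targets == 1, p, 1 - p)
assert torch.allclose(pt, pt_manual)
```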
There are also some very nice implementations that work for multiclass problems. These implementations can be found here and here
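For reference, a minimal multiclass sketch follows the same pattern with `F.cross_entropy` instead of the binary version (the function name and default `gamma` here are my own choices, not from any of the linked implementations):

```python
import torch
import torch.nn.functional as F

def focal_loss_multiclass(logits, targets, gamma=2.0):
    # cross_entropy returns -log(p_t) per sample, so exp(-ce) is p_t
    ce = F.cross_entropy(logits, targets, reduction='none')
    pt = torch.exp(-ce)
    return ((1 - pt) ** gamma * ce).mean()

logits = torch.randn(4, 3)          # batch of 4, 3 classes
targets = torch.tensor([0, 2, 1, 0])
loss = focal_loss_multiclass(logits, targets)
```

An alpha weighting per class could be added via the `weight` argument of `cross_entropy`.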