How to implement focal loss in pytorch?

I implemented multi-class Focal Loss in pytorch. Bellow is the code. log_pred_prob_onehot is batched log_softmax in one_hot format, target is batched target in number(e.g. 0, 1, 2, 3).

import torch
from torch.autograd import Variable

class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=2):
        super(FocalLoss, self).__init__()
        self.gamma = gamma

    def forward(self, log_pred_prob_onehot, target):
        pred_prob_oh = torch.exp(log_pred_prob_onehot)
        pt = Variable(pred_prob_oh.gather(1, target.unsqueeze(1)), requires_grad=True)
        modulator = (1 - pt) ** self.gamma
        mce = modulator * (-torch.log(pt))

        return mce.mean()

However, when I tested it, it performed poorly. I read the Focal Loss paper a couple of times; it seems straightforward, so maybe I didn't understand it very well. I'd appreciate it if anybody can correct me! Or if there is a workable implementation, please let me know! Thanks in advance!


Maybe this repo can solve your problem?


Thanks a lot @BowieHsu!

I also found this one is very good.

I ported this code to my program and it works.


I guess there is something wrong in the original code that breaks the computation graph and makes the loss not decrease. I suspect it is this line:

    pt = Variable(pred_prob_oh.gather(1, target.unsqueeze(1)), requires_grad=True)

Does torch.gather support autograd? Is there any way to implement this?
Many thanks!
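For what it's worth, `torch.gather` does support autograd; the problem in the snippet quoted above is re-wrapping its result in a new `Variable` with `requires_grad=True`, which detaches it from the graph. A minimal sketch on modern PyTorch (where plain tensors replace `Variable`):

```python
import torch
import torch.nn.functional as F

# torch.gather participates in autograd; no Variable wrapper is needed
# (and wrapping with requires_grad=True would detach the graph).
logits = torch.randn(4, 3, requires_grad=True)   # batch of 4, 3 classes
target = torch.tensor([0, 2, 1, 2])

log_prob = F.log_softmax(logits, dim=1)
# Select the log-probability of the target class for each sample.
log_pt = log_prob.gather(1, target.unsqueeze(1)).squeeze(1)
pt = log_pt.exp()

gamma = 2.0
loss = ((1 - pt) ** gamma * (-log_pt)).mean()
loss.backward()
```

If `logits.grad` is populated after `backward()`, gradients flowed through `gather` as expected.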


@ben Hi Ben,
have you tried focal loss on SSD or Faster R-CNN?
How much does it improve mAP?

@BowieHsu I used it in my own project, which has a multi-class, imbalanced data set. So far, it is not as good as reported in the paper; the data is still very hard to train on.


I also tried to use it in my own project. I found I had to reduce the learning rate by a factor of 10, which led to a better first iteration, but with the reduced learning rate the precision barely improves over the epochs. Maybe increasing the learning rate after the first epoch could help.

I tried it in my project, with FPN in Faster R-CNN; it was not as good as cross entropy, though.

I haven't read the paper in detail, but I thought the terms (1-p)^gamma and p^gamma are for weighting only; they should not be back-propagated during gradient descent. Maybe you need to detach() your variables?

After some checking, the weighting terms (1-p)^gamma and p^gamma are back-propagated as well. You can refer to:


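That matches what a quick numerical check shows: keeping the (1 - pt)**gamma factor in the graph versus detach()-ing it gives different gradients (a small sketch; `focal` here is a throwaway helper, not a library function):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3)
target = torch.tensor([0, 2])
gamma = 2.0

def focal(x, detach_weight):
    # Focal loss on raw logits; optionally detach the modulating factor.
    log_pt = F.log_softmax(x, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
    w = (1 - log_pt.exp()) ** gamma
    if detach_weight:
        w = w.detach()          # treat (1 - pt)**gamma as a constant weight
    return (w * -log_pt).mean()

a = logits.clone().requires_grad_(True)
focal(a, detach_weight=False).backward()

b = logits.clone().requires_grad_(True)
focal(b, detach_weight=True).backward()

# a.grad and b.grad differ: the factor carries its own gradient term.
```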
Hi Ben.
Have you confirmed that training with gamma=0 is the same as training with cross entropy loss?
I tried that in my implementation of focal loss, and the result became very different :scream::scream:

I also asked someone to answer my forum question; I can't identify the problem.
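For anyone debugging this, the gamma=0 equivalence is easy to sanity-check numerically: the modulating factor (1 - pt)**0 is 1, so focal loss should match `F.cross_entropy` exactly. A small sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 5, requires_grad=True)
target = torch.randint(0, 5, (8,))

# Focal loss with gamma = 0: the modulating factor (1 - pt)**0 is 1,
# so this should reduce to plain cross entropy.
log_pt = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
gamma = 0.0
focal = ((1 - log_pt.exp()) ** gamma * -log_pt).mean()

ce = F.cross_entropy(logits, target)
print(torch.allclose(focal, ce))   # True
```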

I found this one is pretty good, except for some small syntax issues in Python 3. Enjoy!


Thank you for helping me, Ben!!

I completed an implementation of focal loss for semantic segmentation.

You can now find one_hot and focal loss implementations in torchgeometry:


I'd like to add something to this, since it led me to an error.
With the implementation suggested by @Ben (marvis/pytorch-yolo2/blob/master/), beware: specify alpha as a (C, 1)-shaped tensor rather than a (C,)-shaped one.
Otherwise the implementation will still run, but the loss will be computed as a dot product between the batch alpha values and the batch class probabilities, which makes no conceptual sense.
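One way to sidestep the shape pitfall entirely is to index alpha by the target labels, so each sample picks up a scalar weight and no ambiguous broadcasting happens (a sketch of my own, not taken from the linked repo):

```python
import torch
import torch.nn.functional as F

# Sketch: index alpha by the target labels so each sample gets a scalar
# weight, avoiding any ambiguous broadcasting between alpha and the
# class probabilities.
alpha = torch.tensor([0.25, 0.75, 1.0])   # one weight per class, shape (C,)
gamma = 2.0

logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])

log_pt = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)
pt = log_pt.exp()
at = alpha[target]                        # shape (N,): per-sample alpha
loss = (at * (1 - pt) ** gamma * -log_pt).mean()
```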

quick update: this can be found in kornia:

import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=2):
        super(FocalLoss, self).__init__()
        self.gamma = gamma

    def forward(self, inputs, targets):
        ps = F.softmax(inputs, dim=1)
        ne_ps = 1 - ps
        ws = torch.pow(ne_ps, self.gamma)
        return F.nll_loss(ps.pow(ws), targets)

Can anyone give me comments on my implementation?
I was using it to solve a classification problem with imbalanced classes,
but when I checked the TensorBoard logging, the loss was always the same.
I changed to nll_loss and the loss is changing now, but I am still looking forward to others' comments :slight_smile:

@edwardpwtsoi I think if you are using NLL loss then you have to apply log-softmax first.
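To make that concrete, here is one way the class above could be reworked (my own sketch, not the original poster's code): pass log-probabilities to `F.nll_loss` and multiply them by the (1 - p)**gamma factor, rather than raising p to a tensor power.

```python
import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, inputs, targets):
        log_p = F.log_softmax(inputs, dim=1)   # nll_loss expects log-probs
        p = log_p.exp()
        w = (1 - p) ** self.gamma              # modulating factor per class
        # nll_loss picks -(w * log_p) at the target index and averages,
        # i.e. mean over the batch of (1 - pt)**gamma * (-log pt).
        return F.nll_loss(w * log_p, targets)
```

With gamma=0 this reduces to `F.cross_entropy`, which is a handy sanity check.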

How would ignore_index apply to focal loss?
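There is no built-in ignore_index for focal loss as far as I know; one workaround (a sketch under that assumption, with a hand-rolled `focal_loss` helper) is to mask out the ignored positions before averaging:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, gamma=2.0, ignore_index=-100):
    """Focal loss that skips positions labelled ignore_index."""
    valid = target != ignore_index
    safe_target = target.clone()
    safe_target[~valid] = 0                    # any in-range dummy label
    log_pt = F.log_softmax(logits, dim=1)
    log_pt = log_pt.gather(1, safe_target.unsqueeze(1)).squeeze(1)
    loss = (1 - log_pt.exp()) ** gamma * -log_pt
    return loss[valid].mean()                  # average over valid positions only
```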

Just FYI: this comment here:

says that there is an implementation of focal loss in torchvision.