Is this a correct implementation for focal loss in pytorch?

Hi,
Here is my implementation. I based it on the docs for torch.nn.functional.log_softmax at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.log_softmax. So far it's working well on a class-imbalance problem.
Please let me know if it works for you.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    
    def __init__(self, weight=None, gamma=2., reduction='none'):
        super().__init__()
        self.weight = weight        # optional per-class weights, as in nn.NLLLoss
        self.gamma = gamma          # focusing parameter; gamma=0 recovers cross entropy
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        # Scale the log-probabilities by (1 - p)^gamma before taking the NLL,
        # which down-weights the loss on well-classified examples
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction
        )