Is this a correct implementation for focal loss in pytorch?

Hello, I am new to PyTorch and currently working on a text classification task using deep learning networks. The dataset contains two classes and is highly imbalanced (pos:neg == 100:1), so I want to give focal loss a try.

I have seen some focal loss implementations, but they are a little complicated. So I implemented focal loss (from Focal Loss for Dense Object Detection) with pytorch==1.0 and python==3.6.5. It performs about the same as the standard binary cross-entropy loss, sometimes worse. Did I implement it correctly?

Here is the code:

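```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """
    binary focal loss
    """

    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        # class weights: alpha for class 0, 1 - alpha for class 1
        self.weight = torch.Tensor([alpha, 1-alpha])
        self.nllLoss = nn.NLLLoss(weight=self.weight)
        self.gamma = gamma

    def forward(self, input, target):
        softmax = F.softmax(input, dim=1)
        log_logits = torch.log(softmax)  # becomes -inf if a probability underflows to 0
        fix_weights = (1 - softmax) ** self.gamma  # modulating factor (1 - p)^gamma
        logits = fix_weights * log_logits
        # NLLLoss picks the target-class entry, negates it and applies the class weight
        return self.nllLoss(logits, target)
```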
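A quick way to sanity-check an implementation like this is to set gamma = 0, where focal loss should reduce exactly to cross-entropy. The sketch below is a hedged variant, not the poster's code: the function name `softmax_focal_loss` is hypothetical, it uses `F.log_softmax` instead of `torch.log(F.softmax(...))` to avoid `log(0)`, and it treats `alpha` as an optional per-class weight tensor looked up per sample (an assumption). Unlike `nn.NLLLoss(weight=...)` with its default mean reduction, it takes a plain mean instead of dividing by the sum of the weights.

```python
import torch
import torch.nn.functional as F


def softmax_focal_loss(logits, target, gamma=2.0, alpha=None):
    """Focal loss for logits of shape (N, C) and class indices of shape (N,).

    `alpha` (optional, an assumption of this sketch) is a 1-D tensor of
    per-class weights, looked up per sample as alpha_t.
    """
    log_prob = F.log_softmax(logits, dim=1)                      # stable log p, no log(0)
    log_pt = log_prob.gather(1, target.unsqueeze(1)).squeeze(1)  # log p_t for each sample
    pt = log_pt.exp()
    loss = -((1 - pt) ** gamma) * log_pt                         # (1 - p_t)^gamma * CE
    if alpha is not None:
        loss = alpha[target] * loss                              # scale by alpha_t
    return loss.mean()


torch.manual_seed(0)
logits = torch.randn(8, 2)               # made-up model outputs
target = torch.randint(0, 2, (8,))       # made-up labels
loss = softmax_focal_loss(logits, target, gamma=2.0, alpha=torch.tensor([0.25, 0.75]))
```

With `gamma=0.0` and no `alpha`, the result should match `F.cross_entropy(logits, target)`; with `gamma > 0` every per-sample term is at most its cross-entropy term, since `(1 - p_t)^gamma <= 1`.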

I am also interested in knowing how to implement focal loss. It would be great if someone who has implemented it could help us out here!

You can check out kornia's implementation here

Its input has shape (batch_size, num_classes, H, W), but I have an input matrix of shape (batch_size, num_classes) and a target matrix of the same shape (batch_size, num_classes). So I have come up with this:


```python
import torch
import torch.nn as nn


class FocalLoss(nn.Module):

    def __init__(self, gamma=1.0):
        super(FocalLoss, self).__init__()
        self.gamma = torch.tensor(gamma, dtype=torch.float32)
        self.eps = 1e-6

    def forward(self, input, target):
        # input is not probabilities, it is the raw CNN output (logits)
        # input and target shape: (bs, n_classes)
        # sigmoid
        probs = torch.sigmoid(input)
        log_probs = -torch.log(probs)

        # multiplying by target keeps only the terms where target == 1
        focal_loss = torch.sum(torch.pow(1 - probs + self.eps, self.gamma).mul(log_probs).mul(target), dim=1)
        # bce_loss = torch.sum(log_probs.mul(target), dim=1)

        return focal_loss.mean()  # , bce_loss
```

Please let me know if it's good or not.

Try this:

```python
def forward(self, inputs, targets):
    BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
    pt = torch.exp(-BCE_loss)  # prevents nans when probability 0
    F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
    return F_loss.mean()
```

Remember to set alpha to address class imbalance, and keep in mind that this will only work for binary classification.
I think this is very similar to your implementation; it just uses the BCE function, which does the sigmoid and those .mul() calls for you, and it also prevents the nan problem present in your implementation when the probability is 0, i.e. log(0) = -inf.
There are also some very nice implementations that work for multiclass problems. These implementations can be found here and here.
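For completeness, here is that snippet wrapped in a full `nn.Module`. This is a hedged sketch: the class name `BinaryFocalLoss` is hypothetical, and the defaults `alpha=0.25, gamma=2.0` are assumed (they are the values used in the question). Because every operation is elementwise, it also accepts the `(batch_size, num_classes)` multi-label inputs and targets described earlier in the thread. Note that `alpha` here is a single scalar, so it scales positives and negatives equally rather than acting as a per-class alpha_t.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BinaryFocalLoss(nn.Module):
    """Sigmoid focal loss on top of BCE-with-logits (hypothetical class name).

    Everything is elementwise, so inputs/targets can be (N,) or
    (N, num_classes) multi-label tensors of the same shape.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # single scalar: scales positives and negatives alike
        self.gamma = gamma

    def forward(self, inputs, targets):
        # per-element BCE computed from logits, so log(0) is never taken
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce)  # probability the model assigns to the true label
        return (self.alpha * (1 - pt) ** self.gamma * bce).mean()
```

With `alpha=1.0` and `gamma=0.0` it reduces to plain `F.binary_cross_entropy_with_logits`, which makes a convenient unit test.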


Hey Diego,

I have a couple of questions about your implementation.

  1. What is alpha for a binary focal loss? In binary segmentation there is only one class (the one you’re trying to segment)
  2. Your implementation isn't using an epsilon; is that intentional?

Thanks!

Hi, just to clarify, this is not my implementation; I found it somewhere on Kaggle, but I can't find the link now. I will try to answer your questions nonetheless:

  1. Alpha is a hyperparameter that you can tune to assign more importance to samples from class A or B. I don't know anything about binary segmentation, so correct me if I am wrong, but I assume it must have at least 2 classes: the class you are trying to segment and the background.
  2. The epsilon is used to avoid numerical instability when the probability equals 0. Here, pt is computed as torch.exp(-BCE_loss), so log(0) is never evaluated and no epsilon is needed.
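To see why no epsilon is needed: BCE-with-logits equals -log(pt) per element, so exp(-BCE) recovers pt (p for a positive target, 1 - p for a negative one) without evaluating log(0). A small check with made-up logits, including one extreme value where the naive -log(sigmoid(x)) overflows in float32:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, -1.0, -200.0])   # made-up values; the last is extreme
targets = torch.tensor([1.0, 0.0, 1.0])

# stable per-element BCE, then pt = exp(-BCE)
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
pt = torch.exp(-bce)

# pt from its definition: p if the target is 1, otherwise 1 - p
p = torch.sigmoid(logits)
pt_direct = torch.where(targets == 1, p, 1 - p)

# naive version: sigmoid(-200) underflows to 0, so -log(0) gives inf
naive = -torch.log(p) * targets
```

Here `bce` stays finite for every element, while `naive` is infinite for the last one.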

You can find another, perhaps clearer, implementation here

Hi,
Here is my implementation. I have tried to use the info on torch.nn.functional.log_softmax from https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.log_softmax. So far, it's working well on a class-imbalance problem.
Please let me know if it works for you.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):

    def __init__(self, weight=None,
                 gamma=2., reduction='none'):
        nn.Module.__init__(self)
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        log_prob = F.log_softmax(input_tensor, dim=-1)  # numerically stable log p
        prob = torch.exp(log_prob)
        # (1 - p)^gamma * log p for every class; nll_loss then picks the target class
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction
        )
```

Looks like it works for me too. My input size is NxC, and the target size is N.

Hey, thanks for the implementation. Could you say something about the weight variable? What do you use for it?

The weight should be a 1-D tensor indicating the relative importance of each class. The balanced case, i.e. weight=None, is equivalent to a 1-D tensor whose values are all equal, e.g. 1. For class-imbalance problems, it can be tweaked to adjust for the imbalance, e.g. [0.5, 1] in a binary classification problem where the first class is twice as likely as the second to appear in the target variable. Take a look at the mathematical formulation here: https://pytorch.org/docs/stable/nn.html#nllloss
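One common heuristic (an assumption here, not the only choice) is to turn class counts into weights by inverse frequency, scaled so the rarest class gets weight 1; the counts below are made up to reproduce the [0.5, 1] example. The check also shows how the weighted 'mean' reduction behaves, following the NLLLoss formula in the linked docs: the weighted per-sample losses are divided by the sum of the weights of the target classes, not by N.

```python
import torch
import torch.nn.functional as F

# made-up label counts: class 0 appears twice as often as class 1
counts = torch.tensor([200.0, 100.0])
weight = (1.0 / counts) / (1.0 / counts).max()   # inverse frequency, largest weight = 1

torch.manual_seed(0)
logits = torch.randn(6, 2)                 # made-up model outputs
target = torch.tensor([0, 0, 0, 0, 1, 1])

weighted = F.cross_entropy(logits, target, weight=weight)

# the same value by hand: sum of w_t * loss divided by sum of w_t
per_sample = F.cross_entropy(logits, target, reduction='none')
manual = (weight[target] * per_sample).sum() / weight[target].sum()
```

The same `weight` tensor can be passed to `F.nll_loss` in the implementation above.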

Don’t you use alpha in your implementation?

kornia’s implementation does not seem to treat alpha properly.
It seems to use the same value for both classes.

Here is my implementation along the lines of @Diego with alpha implemented. Binary case only.

```python
import torch


def focal_loss(bce_loss, targets, gamma, alpha):
    """Binary focal loss, mean.

    Per https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/5 with
    improvements for alpha.
    :param bce_loss: Binary Cross Entropy loss, a torch tensor (per element, i.e. reduction='none').
    :param targets: a torch tensor containing the ground truth, 0s and 1s.
    :param gamma: focal loss power parameter, a float scalar.
    :param alpha: weight of the class indicated by 1, a float scalar.
    """
    p_t = torch.exp(-bce_loss)
    alpha_tensor = (1 - alpha) + targets * (2 * alpha - 1)  # alpha if target = 1 and 1 - alpha if target = 0
    f_loss = alpha_tensor * (1 - p_t) ** gamma * bce_loss
    return f_loss.mean()
```

I believe alpha and the weight parameter are essentially doing the same thing here: they just scale the loss linearly. This is a neat way of implementing alpha.
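The `alpha_tensor` expression is a branch-free way of writing alpha_t (alpha for target 1, 1 - alpha for target 0). A small check with made-up logits shows that it matches an explicit `torch.where`, and that passing it as the elementwise `weight` argument of `F.binary_cross_entropy_with_logits` scales the per-element loss in exactly the same way, which is the sense in which alpha and a weight do the same thing. This assumes `bce_loss` is computed with `reduction='none'`; an already-averaged scalar would give a single batch-level p_t.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, -1.0, 0.5, -3.0])   # made-up model outputs
targets = torch.tensor([1.0, 0.0, 0.0, 1.0])
alpha = 0.25

# focal_loss expects per-element BCE, so use reduction='none'
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')

# the arithmetic trick from the post vs. an explicit torch.where
alpha_tensor = (1 - alpha) + targets * (2 * alpha - 1)
alpha_where = torch.where(targets == 1,
                          torch.full_like(targets, alpha),
                          torch.full_like(targets, 1 - alpha))

# passing alpha_tensor as BCE's elementwise weight scales the loss the same way
weighted_bce = F.binary_cross_entropy_with_logits(
    logits, targets, weight=alpha_tensor, reduction='none')
```

For the 'mean' reduction the two can differ between losses: `F.binary_cross_entropy_with_logits` averages the weighted values over N, whereas `F.nll_loss` divides by the sum of the class weights.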

Here is an implementation of focal loss for multi-class classification:

[Image: the focal loss formula]

Here, -log(pt) is our ordinary cross entropy loss.
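For reference, these are the definitions from Focal Loss for Dense Object Detection that the code follows, written for the binary case with p the predicted probability of class 1; in the multi-class case p_t is simply the softmax probability of the target class:

$$
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}
\qquad
\mathrm{CE}(p_t) = -\log(p_t)
\qquad
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t)
$$

Here alpha_t is defined analogously to p_t (alpha for y = 1, 1 - alpha otherwise), and setting gamma = 0 with alpha_t = 1 recovers the ordinary cross-entropy.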

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.modules.loss._WeightedLoss):
    def __init__(self, weight=None, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight  # weight parameter will act as the alpha parameter to balance class weights

    def forward(self, input, target):
        # with reduction='mean' (the default), ce_loss is a single batch-averaged scalar
        ce_loss = F.cross_entropy(input, target, reduction=self.reduction, weight=self.weight)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
        return focal_loss
```

Does this work properly for minibatches?
I guess F.cross_entropy() gives the average cross-entropy over the batch, and pt is then a scalar that modifies the loss for the whole batch. So if some of the input-target pairs have a low ce_loss and some have a high one, do they all get the same focal adjustment?
If so, this might fix it:

```python
def forward(self, input, target):
    patterns = target.shape[0]
    tot = 0
    for b in range(patterns):
        ce_loss = F.cross_entropy(input[b:b+1], target[b:b+1], reduction=self.reduction, weight=self.weight)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        tot = tot + focal_loss
    return tot / patterns
```
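A vectorized alternative to the loop, sketched under the assumption that the goal is the per-sample focal term: ask `F.cross_entropy` for `reduction='none'` so that `ce` and `pt` have one entry per sample, and apply the class weight afterwards as alpha_t. Applying it afterwards matters, because passing `weight=` into `cross_entropy` would make `ce = w * (-log p_t)`, and `exp(-ce)` would no longer equal p_t. It also avoids a subtlety of the loop: with `reduction='mean'` and a weight, each one-sample call divides `w * loss` by that same `w`, so the class weights cancel out. The function name `per_sample_focal_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def per_sample_focal_loss(input, target, gamma=2.0, weight=None):
    """Multi-class focal loss with the focal term applied per sample.

    input: logits (N, C); target: class indices (N,);
    weight: optional per-class tensor (C,), used here as alpha_t.
    """
    # reduction='none' keeps one cross-entropy value per sample
    ce = F.cross_entropy(input, target, reduction='none')
    pt = torch.exp(-ce)                    # p_t for each sample
    focal = (1 - pt) ** gamma * ce
    if weight is not None:
        focal = weight[target] * focal     # alpha_t applied after computing p_t
    return focal.mean()


torch.manual_seed(0)
logits = torch.randn(4, 3)                 # made-up model outputs
labels = torch.tensor([0, 2, 1, 2])
loss = per_sample_focal_loss(logits, labels, gamma=2.0, weight=torch.tensor([1.0, 2.0, 1.0]))
```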

Not sure, but is pt not the network’s prediction for class t?

F.cross_entropy takes logits from the model. Logits are the raw outputs of the model, not probabilities. Since the cross-entropy of a sample is -log(pt), torch.exp(-ce_loss) is used to recover the probability pt.
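This can be checked numerically: with `reduction='none'`, `F.cross_entropy` returns -log(p_t) for each sample, so `torch.exp(-ce)` equals the softmax probability of the target class, i.e. the network's prediction for class t. With the default `reduction='mean'`, however, `exp(-ce)` is the geometric mean of those p_t values over the batch rather than a per-sample probability. The logits and targets below are made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)                 # raw model outputs, not probabilities
target = torch.tensor([0, 2, 1, 2])

ce = F.cross_entropy(logits, target, reduction='none')   # -log p_t per sample
pt_from_ce = torch.exp(-ce)

# p_t read directly from the softmax probabilities
pt_direct = F.softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)

# with the default 'mean' reduction, exp(-ce) is a batch-level geometric mean
pt_batch = torch.exp(-F.cross_entropy(logits, target))
```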

Hope this helps.
