Focal loss performs worse than cross-entropy loss in classification

I am working on a CNN-based classification problem.
I am using torchvision.datasets.ImageFolder to set up my dataset, then passing it to a DataLoader and feeding it to
a pretrained resnet34 model from torchvision.

I have a highly imbalanced dataset, which hinders model performance.
Say ‘0’: 1000 images, ‘1’: 300 images.
I know I have two broad strategies: resampling (data level) or changing the loss function (algorithm level).
I first tried replacing the cross-entropy loss with a custom FocalLoss, but somehow I am getting even worse performance.

My training function looks like this. Can anybody tell me what I am missing or doing wrong?

from collections import defaultdict

import torch
from torch import optim
from torch.optim import lr_scheduler

def train_model(model, data_loaders, dataset_sizes, device, n_epochs=20):
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    loss_fn = FocalLoss().to(device)
    history = defaultdict(list)
    best_accuracy = 0
    for epoch in range(n_epochs):
        print(f'Epoch {epoch + 1}/{n_epochs}')
        print('-' * 10)
        train_acc, train_loss = train_epoch(...)  # arguments omitted
        print(f'Train loss {train_loss} accuracy {train_acc}')
        val_acc, val_loss = eval_model(...)  # arguments omitted
        print(f'Val   loss {val_loss} accuracy {val_acc}')
        if val_acc > best_accuracy:
            # keep the weights of the best epoch so far
            torch.save(model.state_dict(), 'best_model_state.bin')
            best_accuracy = val_acc
    print(f'Best val accuracy: {best_accuracy}')
    return model, history

The custom FocalLoss function from the web looks like below (sorry, I forgot the reference):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # alpha is the per-class weighting factor, gamma is the focusing parameter
    def __init__(self, gamma=0, alpha=None, size_average=True):
    #def __init__(self, gamma=2, alpha=0.25, size_average=False):    
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = logpt.exp()

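        if self.alpha is not None:
            if self.alpha.type() != input.type():
                self.alpha = self.alpha.type_as(input)   # match the tensor type of the logits
            at = self.alpha.gather(0, target.view(-1))   # weight of each sample's true class
            logpt = logpt * at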

        loss = -1 * (1 - pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()
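
As a sanity check on the class above, the per-sample value it computes can be written out in plain Python. This is only a sketch: `focal_loss_single` is a hypothetical helper name, not part of PyTorch, and it handles one sample rather than a batch. It suggests that with the constructor defaults (`gamma=0`, `alpha=None`) the focusing term disappears and the result equals ordinary cross-entropy, so `FocalLoss()` called with no arguments should behave like `nn.CrossEntropyLoss`.

```python
import math

def focal_loss_single(logits, target, gamma=0.0, alpha=None):
    """Focal loss for one sample (hypothetical helper mirroring FocalLoss.forward)."""
    # log-softmax, computed stably by subtracting the max logit
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    logpt = logits[target] - log_z      # log-probability of the true class
    pt = math.exp(logpt)                # probability of the true class
    if alpha is not None:
        logpt = logpt * alpha[target]   # per-class weighting factor
    return -((1.0 - pt) ** gamma) * logpt

logits = [2.0, 0.5]                                     # made-up logits
ce = focal_loss_single(logits, target=1)                # gamma=0, alpha=None
fl = focal_loss_single(logits, target=1, gamma=2.0)     # focusing enabled
print(f"cross-entropy: {ce:.4f}, focal (gamma=2): {fl:.4f}")
```

Passing non-default values, for example `FocalLoss(gamma=2, alpha=[w0, w1])`, is what actually switches the focal behaviour on.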

Take a look at the original paper for more insight.
For focal loss, it says:

"In practice α may be set by inverse class frequency or treated as a hyperparameter to set by cross validation."

Also, your problem is not highly imbalanced. You can use weighted cross entropy and get good performance.
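
A minimal sketch of what weighted cross entropy does, written in plain Python so the arithmetic is visible. The class counts are the 1000/300 example from the question; the inverse-frequency weights, the logits and the helper name `weighted_cross_entropy` are illustrative assumptions. In PyTorch the same effect comes from passing a `weight` tensor to `nn.CrossEntropyLoss`, which with the default mean reduction also divides by the sum of the weights of the targets in the batch.

```python
import math

def weighted_cross_entropy(batch_logits, targets, class_weights):
    """Weighted mean of per-sample cross-entropy (hypothetical helper)."""
    total, weight_sum = 0.0, 0.0
    for logits, t in zip(batch_logits, targets):
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        nll = -(logits[t] - log_z)          # per-sample cross-entropy
        total += class_weights[t] * nll     # samples of the rare class count more
        weight_sum += class_weights[t]
    return total / weight_sum               # normalise by the weights actually used

counts = [1000, 300]                        # images per class, from the question
n = sum(counts)
class_weights = [n / (len(counts) * c) for c in counts]  # inverse frequency

batch_logits = [[1.2, -0.3], [0.4, 0.1], [-0.5, 0.8]]    # made-up logits
targets = [0, 0, 1]
loss = weighted_cross_entropy(batch_logits, targets, class_weights)
print(class_weights, loss)
```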


When I set alpha to the inverse class frequency, I get very bad results.

So, if the class ratios are [0.84, 0.16] for labels 0 and 1, what do you suggest for gamma and alpha?

Also, I am planning to use WandB sweeps to set up a grid search over gamma and alpha. What range of values should I run my experiments on? I think gamma is in [0, 5], but I am not sure about alpha.

Based on the class ratios, your problem is not that unbalanced.
I think you can normalize alpha so that its geometric mean equals 1:

alpha = (0.84*0.16)^(0.5) / [0.84, 0.16] = [0.44, 2.29]

or so that the expected weight is 1 (this is what’s implemented in weighted cross entropy):

alpha = [0.84, 0.16]^(-1)
alpha = alpha / (0.84*alpha[0] + 0.16*alpha[1])
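
Both normalisations can be checked numerically. The sketch below uses the class ratios [0.84, 0.16] from this thread; the variable names are only illustrative.

```python
import math

freq = [0.84, 0.16]                      # class ratios for labels 0 and 1

# Option 1: inverse frequency, scaled so the geometric mean of alpha is 1
geo = math.sqrt(freq[0] * freq[1])
alpha_geo = [geo / f for f in freq]

# Option 2: inverse frequency, scaled so the expected weight
# sum(freq[i] * alpha[i]) over the class distribution is 1
inv = [1.0 / f for f in freq]
expected = sum(f * a for f, a in zip(freq, inv))
alpha_exp = [a / expected for a in inv]

print([round(a, 3) for a in alpha_geo])
print([round(a, 3) for a in alpha_exp])
```

This gives roughly [0.44, 2.29] for the geometric-mean version and [0.60, 3.13] for the expected-weight version. Both keep the same ratio between the two classes and differ only in overall scale, which mostly interacts with the learning rate.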

Also, you should be able to get a good enough result using “weighted cross entropy”.
Focal loss is specialized for object detection with very unbalanced classes, where many of the predicted boxes contain no object and the decision boundaries are very hard to learn. As a result, many correct decisions end up with probabilities close to 0.5, and that is where focal loss helps.
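
To make the down-weighting concrete, the sketch below prints the factor (1 - pt)^gamma that focal loss multiplies onto the cross-entropy of each example, where pt is the predicted probability of the true class. The probabilities and gamma values are arbitrary illustrative choices.

```python
def modulating_factor(pt, gamma):
    """Scale that focal loss applies on top of cross-entropy: (1 - pt) ** gamma."""
    return (1.0 - pt) ** gamma

for gamma in (0, 1, 2, 5):
    row = [f"{modulating_factor(pt, gamma):.4f}" for pt in (0.5, 0.7, 0.9, 0.99)]
    print(f"gamma={gamma}: {row}")
```

With gamma=2, an example at pt=0.9 keeps about 1% of its cross-entropy loss while one at pt=0.5 keeps 25%, so the uncertain examples dominate the total loss.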

BTW, setting weights by inverse class frequency only helps if you want good average recall (a high detection rate).
It may not help with any metric that accounts for false positives (mIoU, F1, precision).


Hello, thanks for your response. Would there be a problem if I use focal loss for binary classification, or is it mostly suggested for segmentation and bounding-box detection?

It’s used for checking whether a box has an object in it or not, so it’s binary classification.
In segmentation, if you treat it as a multi-label problem (sigmoid at the end of the model or in the loss) and use a threshold for the decision (output > 0.5), you’re basically implementing N one-versus-all binary classifiers.
So it’s not a problem.
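
The one-versus-all view can be sketched in a few lines of plain Python. The logits are made-up values and `sigmoid` and `multilabel_decisions` are defined locally for illustration; each class is thresholded at 0.5 on its own, so every output acts as a separate binary classifier.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def multilabel_decisions(logits, threshold=0.5):
    """One independent yes/no decision per class."""
    return [sigmoid(z) > threshold for z in logits]

# One pixel (or one box) with three class logits
logits = [2.1, -0.7, 0.3]
print(multilabel_decisions(logits))
```

Because the decisions are independent, a binary focal loss can be applied to each output separately in the same way as for a single binary classifier.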