Focal loss performs worse than cross-entropy loss in classification

Hello,
I am working on CNN-based classification.
I am using torchvision.datasets.ImageFolder to set up my dataset, which I pass to a DataLoader and feed to a pretrained resnet34 model from torchvision.

I have a highly imbalanced dataset, which hinders model performance.
Say ‘0’: 1000 images, ‘1’: 300 images.
I know there are two broad strategies: resampling (data level) or changing the loss function (algorithm level).
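For the data-level option, one common PyTorch approach is to oversample the minority class with `torch.utils.data.WeightedRandomSampler`. A minimal sketch, assuming the labels are available as a list of ints (ImageFolder exposes them as `.targets`); `make_balanced_sampler` is a hypothetical helper name:

```python
import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(targets):
    """Hypothetical helper: draw each class with equal total probability."""
    targets = torch.as_tensor(targets)
    class_counts = torch.bincount(targets)          # e.g. tensor([1000, 300])
    class_weights = 1.0 / class_counts.double()     # rarer class -> larger weight
    sample_weights = class_weights[targets]         # one weight per image
    # replacement=True lets minority images be drawn more than once per epoch
    return WeightedRandomSampler(sample_weights, num_samples=len(targets), replacement=True)


# Example with the class sizes from the question.
labels = [0] * 1000 + [1] * 300
sampler = make_balanced_sampler(labels)
```

In training this would replace `shuffle=True`, e.g. `DataLoader(train_dataset, batch_size=32, sampler=make_balanced_sampler(train_dataset.targets))`, where `train_dataset` and the batch size are placeholders; DataLoader does not accept `shuffle=True` together with a sampler.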
I first tried replacing the cross-entropy loss with a custom FocalLoss, but somehow I am getting even worse performance, as shown below:
[Image: focalloss]

My training function looks like this. Can anybody tell me what I am missing or doing wrong?

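```python
from collections import defaultdict

import torch
from torch import optim
from torch.optim import lr_scheduler

# train_epoch and eval_model are my own helper functions (not shown);
# FocalLoss is the class defined below.
def train_model(model, data_loaders, dataset_sizes, device, n_epochs=20):

    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # Note: if scheduler.step() is called once per epoch, step_size=30 is larger
    # than n_epochs=20, so the learning rate is never decayed.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    # With its defaults (gamma=0, alpha=None) FocalLoss is plain cross-entropy.
    loss_fn = FocalLoss().to(device)
    history = defaultdict(list)
    best_accuracy = 0
    for epoch in range(n_epochs):
        print(f'Epoch {epoch + 1}/{n_epochs}')
        print('-' * 10)
        train_acc, train_loss = train_epoch(
            model,
            data_loaders['train'],
            loss_fn,
            optimizer,
            device,
            scheduler,
            dataset_sizes['train']
        )
        print(f'Train loss {train_loss} accuracy {train_acc}')
        val_acc, val_loss = eval_model(
            model,
            data_loaders['val'],
            loss_fn,
            device,
            dataset_sizes['val']
        )
        print(f'Val   loss {val_loss} accuracy {val_acc}')
        print()
        history['train_acc'].append(train_acc)
        history['train_loss'].append(train_loss)
        history['val_acc'].append(val_acc)
        history['val_loss'].append(val_loss)
        # Keep the checkpoint with the best validation accuracy so far.
        if val_acc > best_accuracy:
            torch.save(model.state_dict(), 'best_model_state.bin')
            best_accuracy = val_acc
    print(f'Best val accuracy: {best_accuracy}')
    # Restore the best checkpoint before returning.
    model.load_state_dict(torch.load('best_model_state.bin'))
    return model, history
```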

The custom FocalLoss class I found on the web looks like this (sorry, I forgot the reference):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    # alpha is the per-class weighting factor, gamma is the focusing parameter.
    # With the defaults (gamma=0, alpha=None) the loss is identical to cross-entropy.
    def __init__(self, gamma=0, alpha=None, size_average=True):
    # def __init__(self, gamma=2, alpha=0.25, size_average=False):
        super().__init__()
        self.gamma = gamma
        # A scalar alpha becomes [alpha, 1 - alpha], so alpha is the weight of class 0.
        if isinstance(alpha, (float, int)):
            alpha = torch.tensor([alpha, 1 - alpha], dtype=torch.float32)
        elif isinstance(alpha, list):
            alpha = torch.tensor(alpha, dtype=torch.float32)
        # Registered as a buffer so it follows the module on .to(device); None is allowed.
        self.register_buffer('alpha', alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)  # log p_t: log-probability of the true class
        pt = logpt.exp()                          # p_t, computed before any alpha weighting

        if self.alpha is not None:
            at = self.alpha.to(device=input.device, dtype=input.dtype).gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()
```
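One thing worth checking: with its default arguments (gamma=0, alpha=None) the class above computes -log p_t averaged over the batch, which is exactly what nn.CrossEntropyLoss computes. So if the run used `FocalLoss()` exactly as in `train_model`, the loss itself was plain cross-entropy, and any difference between the two runs is not caused by the loss. A compact sketch of the same loss written on top of `F.cross_entropy` (the name `focal_loss_from_ce` is hypothetical, and the alpha term is left out) makes the gamma=0 case easy to verify:

```python
import torch
import torch.nn.functional as F


def focal_loss_from_ce(logits, target, gamma=2.0):
    """Multi-class focal loss without alpha, built from per-sample cross-entropy."""
    ce = F.cross_entropy(logits, target, reduction="none")  # -log p_t per sample
    pt = torch.exp(-ce)                                     # p_t of the true class
    return ((1 - pt) ** gamma * ce).mean()


torch.manual_seed(0)
logits = torch.randn(8, 2)
target = torch.randint(0, 2, (8,))

# gamma=0 makes the modulating factor 1, so this is ordinary cross-entropy.
print(torch.allclose(focal_loss_from_ce(logits, target, gamma=0.0),
                     F.cross_entropy(logits, target)))  # True
```

Only with gamma > 0 does the loss start to differ, always downward, since (1 - p_t)^gamma never exceeds 1.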

Take a look at the original paper for more insight.
For focal loss:

" In practice α may be set by inverse class frequency or treated as a hyperparameter to set by cross validation"

Also, your problem is not highly imbalanced. You can use weighted cross-entropy and get good performance.
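A minimal sketch of weighted cross-entropy through the `weight` argument of `nn.CrossEntropyLoss`, assuming the class counts from the question (1000 and 300). The scaling n_samples / (n_classes * count) is the same heuristic as scikit-learn's "balanced" class weights; it is one reasonable choice, not the only one:

```python
import torch
import torch.nn as nn

# Class counts from the example in the question: '0' -> 1000 images, '1' -> 300 images.
class_counts = torch.tensor([1000.0, 300.0])

# Inverse-frequency weights, scaled so the frequency-weighted average weight is 1
# (n_samples / (n_classes * count), i.e. [0.65, ~2.17] here).
weights = class_counts.sum() / (len(class_counts) * class_counts)

loss_fn = nn.CrossEntropyLoss(weight=weights)  # add .to(device) when training on GPU

# Dummy batch just to show the call: 4 samples, 2 classes.
logits = torch.randn(4, 2)
target = torch.tensor([0, 0, 1, 1])
loss = loss_fn(logits, target)
```

With the default `reduction='mean'`, PyTorch divides by the sum of the weights of the targets in the batch rather than by the batch size, so multiplying all weights by the same constant does not change the loss value.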


When I set alpha to the inverse class frequency, I get very bad results.

So, if the class ratios are [0.84, 0.16] for labels 0 and 1, what would you suggest setting gamma and alpha to?

Also, I am planning to use WandB sweeps to set up a grid search over gamma and alpha. What range of values should I run my experiments over? I think gamma should be in [0, 5], but I am not sure about alpha.
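A WandB grid sweep over both parameters can be described with a plain dictionary. A sketch under assumptions: the values are illustrative (gamma spanning the [0, 5] range mentioned above, alpha inside the [.25, .75] range that the focal loss paper reports as best, quoted later in this thread), and `val_loss` is a hypothetical metric name that the training script would have to log with `wandb.log`:

```python
import itertools

# Hypothetical sweep configuration; "val_loss" must be logged by the training run.
sweep_config = {
    "method": "grid",
    "metric": {"name": "val_loss", "goal": "minimize"},
    "parameters": {
        "gamma": {"values": [0.0, 0.5, 1.0, 2.0, 5.0]},
        "alpha": {"values": [0.25, 0.5, 0.75]},
    },
}

# A grid sweep runs every combination, so count them before launching.
params = sweep_config["parameters"]
grid = list(itertools.product(params["gamma"]["values"], params["alpha"]["values"]))
print(len(grid))  # 15
```

The dictionary would then be passed to `wandb.sweep(sweep_config, project=...)` and the returned sweep id to `wandb.agent`. Keep in mind that a scalar alpha in the FocalLoss class above ends up as the weight of class 0, not of the rare class.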


Hi,
Based on the class ratios, your problem is not that unbalanced.
I think you can normalize the weights so that their geometric mean equals 1:

alpha = (0.84*0.16)^(0.5) / [0.84, 0.16] ≈ [0.44, 2.29]

or so that the expected weight is 1 (this is what’s implemented in weighted cross-entropy):

alpha = [0.84, 0.16]^(-1)
alpha = alpha / (0.84*alpha[0] + 0.16*alpha[1]) ≈ [0.595, 3.125]
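Both normalizations can be checked numerically. A plain-Python sketch (the helper names are hypothetical) that starts from the inverse frequencies and rescales them either to a geometric mean of 1 or to an expected weight of 1:

```python
import math


def geometric_mean_weights(freqs):
    """Inverse-frequency weights scaled so that their geometric mean is 1."""
    inv = [1.0 / f for f in freqs]
    gmean = math.exp(sum(math.log(w) for w in inv) / len(inv))
    return [w / gmean for w in inv]


def expected_one_weights(freqs):
    """Inverse-frequency weights scaled so that sum(f_i * w_i) is 1."""
    inv = [1.0 / f for f in freqs]
    expected = sum(f * w for f, w in zip(freqs, inv))
    return [w / expected for w in inv]


freqs = [0.84, 0.16]
print([round(w, 2) for w in geometric_mean_weights(freqs)])  # [0.44, 2.29]
print([round(w, 3) for w in expected_one_weights(freqs)])    # [0.595, 3.125]
```

Both keep the same ratio between the two classes; they differ only in overall scale.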

Also, you should be able to get a good enough result using “weighted cross entropy”.
Focal loss was designed for object detection with very unbalanced classes: many of the predicted boxes contain no object at all, and the decision boundaries are very hard to learn, so many correct decisions have probabilities close to 0.5. That is where focal loss helps.
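The mechanism behind this is the modulating factor (1 - p_t)^γ: focal loss is the cross-entropy -log(p_t) multiplied by that factor, where p_t is the predicted probability of the true class. A small plain-Python sketch (γ = 2 chosen for illustration) prints how much of the cross-entropy survives at a few values of p_t:

```python
import math


def cross_entropy(p_t):
    return -math.log(p_t)


def focal(p_t, gamma):
    # Cross-entropy scaled by the modulating factor (1 - p_t) ** gamma.
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


for p_t in (0.9, 0.6, 0.5, 0.1):
    ratio = focal(p_t, gamma=2.0) / cross_entropy(p_t)
    print(f"p_t={p_t}: FL/CE = {ratio:.2f}")  # 0.01, 0.16, 0.25, 0.81
```

Confidently correct examples (p_t near 1) are almost removed from the loss, while borderline ones around 0.5 keep a quarter of it and badly misclassified ones keep most of it.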

BTW, setting weights by inverse class frequency only helps if you want good average recall (high detection rate).
It may not be good for any metric that involves false positives (mIoU, F1, precision).
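To see this trade-off on a validation set, and to avoid judging checkpoints on accuracy alone as `train_model` does, per-class precision, recall and F1 can be computed from the predictions. A plain-Python sketch; `per_class_metrics` is a hypothetical helper, and `sklearn.metrics.precision_recall_fscore_support` gives the same numbers if scikit-learn is available:

```python
def per_class_metrics(y_true, y_pred, num_classes=2):
    """Return a list of (precision, recall, f1) tuples, one per class."""
    metrics = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics.append((precision, recall, f1))
    return metrics


# A degenerate classifier that always predicts class 0 on an 84/16 split:
y_true = [0] * 84 + [1] * 16
y_pred = [0] * 100
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy)                           # 0.84
print(per_class_metrics(y_true, y_pred))  # class 1 gets (0.0, 0.0, 0.0)
```

With an 84/16 split, 84% accuracy is reachable without ever detecting the minority class, which per-class recall exposes immediately.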


Hello, thanks for your response. Would there be a problem if I use focal loss for binary classification, or is it mostly recommended for segmentation and bounding-box detection?

It’s used to check whether a box contains an object or not, so it is binary classification.
In segmentation, if you treat it as a multi-label problem (sigmoid at the end of the model or in the loss) and threshold the output for the decision (output > 0.5), you are basically implementing N one-versus-all binary classifiers.
It’s not a problem.
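For a binary setup with one logit per sample (or per pixel, in the one-versus-all segmentation view), the sigmoid form of focal loss can be written on top of `F.binary_cross_entropy_with_logits`. A sketch, with `binary_focal_loss` a hypothetical name and α applied to the positive (label 1) class as in the focal loss paper; torchvision also ships `torchvision.ops.sigmoid_focal_loss` along the same lines:

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss; logits and targets have the same shape."""
    targets = targets.float()
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # probability of the true label
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # alpha on positives, 1 - alpha on negatives
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()


torch.manual_seed(0)
logits = torch.randn(16)             # e.g. one output per image from a 1-unit head
targets = torch.randint(0, 2, (16,))
loss = binary_focal_loss(logits, targets)
```

Using this with the resnet34 from the question would assume its final layer is swapped for a single-output one (for example `nn.Linear(model.fc.in_features, 1)`) and the output squeezed to shape (N,); that change is hypothetical, not part of the posted code.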

Sorry for reviving this old topic, but for posterity I thought it would be worth correcting a misconception that also had me confused.

While the original paper does state:

“In practice α may be set by inverse class frequency or treated as a hyperparameter to set by cross validation”

In fact, this can be quite misleading if you don’t read the paper carefully. What they are referring to is the pre-existing practice used with the regular weighted cross entropy loss.

With their focal loss formulation, they actually find that in practice decreasing alpha as gamma is increased helps as a form of compensation. That is, gamma already strongly up-modulates hard-to-classify samples (which typically happen to be the rarer positive examples), so, counter-intuitively, alpha is used in the paper to down-weight the rarer positive class.

I actually find this quite distressing, but that’s the empirical finding.

“Finally we note that α, the weight assigned to the rare class, also has a stable range, but it interacts with γ making it necessary to select the two together. In general α should be decreased slightly as γ is increased (for γ = 2, α = 0.25 works best).”

“We observe that lower α’s are selected for higher γ’s (as easy negatives are downweighted, less emphasis needs to be placed on the positives). Overall, however, the benefit of changing γ is much larger, and indeed the best α’s ranged in just [.25,.75] (we tested α ∈ [.01, .999]). We use γ = 2.0 with α = .25 for all experiments but α = .5 works nearly as well (.4 AP lower).”
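This convention matters for the FocalLoss class posted at the top of this thread: a scalar alpha there becomes [alpha, 1 - alpha], i.e. alpha weights class 0. To follow the paper's convention (α on the rare class, 1 - α on the others), the per-class weights could be built explicitly and passed as a list. A sketch with a hypothetical helper, assuming label 1 is the rare class:

```python
def paper_alpha_to_class_weights(alpha, num_classes=2, rare_class=1):
    """Hypothetical helper: weight `alpha` for the rare class, `1 - alpha` for the others."""
    weights = [1.0 - alpha] * num_classes
    weights[rare_class] = alpha
    return weights


weights = paper_alpha_to_class_weights(0.25)
print(weights)  # [0.75, 0.25]
```

So `FocalLoss(gamma=2, alpha=[0.75, 0.25])` would mirror the paper's γ = 2, α = 0.25 setting, whereas `FocalLoss(gamma=2, alpha=0.25)` would put 0.25 on the majority class 0 instead.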

Naive speculation

Personally I’ve been wondering if alpha is even a good idea in the focal loss formulation. Somehow it feels like perhaps there should be some other way to up-/down-modulate gamma (perhaps based on optimizer momentum or something). However, this is just some naive speculation on my part.