Hello, I am new to PyTorch and currently working on a text classification task using deep learning networks. The dataset contains two classes and is highly imbalanced (pos:neg == 100:1), so I want to try focal loss.
I have seen some focal loss implementations, but they are a bit hard to follow, so I implemented the focal loss ("Focal Loss for Dense Object Detection") myself with pytorch==1.0 and python==3.6.5. It performs about the same as standard binary cross-entropy loss, and sometimes worse. Did I implement it correctly?
Their input has shape (batch_size, num_classes, H, W), but I have an input matrix of shape (batch_size, num_classes) and a target matrix of the same shape (batch_size, num_classes). So I have come up with this:
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, gamma=1.0):
        super(FocalLoss, self).__init__()
        self.gamma = torch.tensor(gamma, dtype=torch.float32)
        self.eps = 1e-6

    def forward(self, input, target):
        # input is not probabilities, it is just the raw CNN output vector
        # input and target shape: (bs, n_classes)
        # sigmoid
        probs = torch.sigmoid(input)
        log_probs = -torch.log(probs)
        focal_loss = torch.sum(torch.pow(1 - probs + self.eps, self.gamma).mul(log_probs).mul(target), dim=1)
        # bce_loss = torch.sum(log_probs.mul(target), dim=1)
        return focal_loss.mean()  # , bce_loss
Remember the alpha parameter to address class imbalance, and keep in mind that this will only work for binary classification.
I think this is very similar to your implementation; it just uses the BCE function, which does the sigmoid and those .mul() calls for you, and which also prevents the NaN problem present in your implementation when a probability is 0, i.e. log(0) = -inf.
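For reference, here is a minimal sketch of that style of implementation, built on F.binary_cross_entropy_with_logits; the class name BCEFocalLoss and the default gamma/alpha values are my own placeholders, not the exact code referred to above:

import torch
import torch.nn as nn
import torch.nn.functional as F


class BCEFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha   # weight for the positive class

    def forward(self, input, target):
        # input, target: (bs, n_classes); target is a float tensor of 0s and 1s
        bce = F.binary_cross_entropy_with_logits(input, target, reduction='none')
        probs = torch.sigmoid(input)
        # probability assigned to the true label of each entry
        pt = torch.where(target == 1, probs, 1 - probs)
        # alpha for positives, (1 - alpha) for negatives
        alpha_t = torch.where(target == 1,
                              torch.full_like(probs, self.alpha),
                              torch.full_like(probs, 1 - self.alpha))
        return (alpha_t * (1 - pt) ** self.gamma * bce).mean()

Because the log is taken inside binary_cross_entropy_with_logits, the log(0) case is handled safely and no manual epsilon is needed.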
There are also some very nice implementations that work for multi-class problems; they can be found here and here.
Hi, just to clarify, this is not my implementation; I found it somewhere on Kaggle, but I can't find the link now. I will try to answer your questions nonetheless:
Alpha is a hyperparameter that you can tune to assign more importance to samples from class A or B. I don't know anything about binary segmentation, so correct me if I am wrong, but I assume it must have at least two classes: the class you are trying to segment and the background.
The epsilon is used to avoid numerical instability when a probability equals 0. In this case torch.exp() will deal with that.
You can find another, perhaps clearer, implementation here.
Weights should be a 1-d tensor indicating the relative class importance. For the balanced case, i.e. weight=None, it is equivalent to a 1-d tensor whose values are all equal, e.g. 1. For class-imbalance problems, this can be tweaked to adjust for the imbalance, e.g. [0.5, 1] in a binary classification problem where the first class is twice as likely to appear in the target variable as the second. Take a look at the mathematical formulation here: https://pytorch.org/docs/stable/nn.html#nllloss
I believe alpha and the weight parameter are essentially doing the same thing here: they just linearly scale the loss. This is a neat way of implementing alpha.
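As a small illustration of the [0.5, 1] weighting described above (the logits and targets here are made up):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 2)             # (batch_size, num_classes)
target = torch.tensor([0, 0, 0, 1])    # class 0 appears more often
weight = torch.tensor([0.5, 1.0])      # down-weight the frequent class
loss = F.cross_entropy(logits, target, weight=weight)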
Does this work properly for minibatches?
I guess F.cross_entropy() gives the average cross-entropy over the batch, so pt is a scalar that modifies the loss for the whole batch. So, if some of the input-target patterns have a low ce_loss and some have a high one, they all get the same focal adjustment?
If so, this might fix it:
def forward(self, input, target):
    patterns = target.shape[0]
    tot = 0
    for b in range(patterns):
        ce_loss = F.cross_entropy(input[b:b+1], target[b:b+1],
                                  reduction=self.reduction, weight=self.weight)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        tot = tot + focal_loss
    return tot / patterns
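The per-sample loop can also be avoided by asking F.cross_entropy for per-sample losses via reduction='none'. A vectorized sketch (the standalone function name focal_loss and its defaults are my own, and I have left the weight argument out here):

import torch
import torch.nn.functional as F


def focal_loss(input, target, gamma=2.0, reduction='mean'):
    # per-sample cross entropy, so each pattern gets its own pt
    ce_loss = F.cross_entropy(input, target, reduction='none')
    pt = torch.exp(-ce_loss)
    focal = (1 - pt) ** gamma * ce_loss
    if reduction == 'mean':
        return focal.mean()
    if reduction == 'sum':
        return focal.sum()
    return focal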
F.cross_entropy takes logits from the model. Logits are the raw outputs of the model; they are not probabilities. That is why torch.exp(-ce_loss) is used to recover the probability of the true class (i.e. pt).
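A quick sanity check of that relation on a toy example of my own: exp(-cross_entropy) recovers the softmax probability of the true class for a single sample.

import torch
import torch.nn.functional as F

logits = torch.randn(1, 3)
target = torch.tensor([2])
ce = F.cross_entropy(logits, target)                        # = -log p(target)
pt = torch.exp(-ce)
print(torch.allclose(pt, F.softmax(logits, dim=1)[0, 2]))   # True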
This may not be correct if you have a class weight, since the weight is applied before multiplying by (1 - pt)**gamma. I think @arkrde's answer works. In PyTorch, F.cross_entropy is equivalent to F.nll_loss(F.log_softmax(...)). The weight (i.e. the alpha) should be applied in nll_loss().
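To make that ordering concrete, here is a minimal sketch under that assumption (the class name WeightedFocalLoss and its defaults are hypothetical, not taken from the thread): the focal factor is computed from the unweighted log-probabilities, and the per-class weight is applied through F.nll_loss.

import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, weight=None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight   # 1-d tensor of per-class weights (the alpha), or None

    def forward(self, input, target):
        # input: (bs, n_classes) logits; target: (bs,) class indices
        log_probs = F.log_softmax(input, dim=1)
        # pt from the *unweighted* log-probability of the true class
        pt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1).exp()
        # weighted NLL per sample, then modulated by (1 - pt)**gamma
        nll = F.nll_loss(log_probs, target, weight=self.weight, reduction='none')
        return ((1 - pt) ** self.gamma * nll).mean()

This way the weight only scales the loss term and does not distort pt itself.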