Is this a correct implementation of focal loss in PyTorch?

Hi, this is my implementation:

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    """Implementation of Focal Loss"""
    def __init__(self, weight=None, gamma=2, reduction="mean"):
        super(FocalLoss, self).__init__()
        # per-sample cross-entropy with class weights (the alpha term)
        self.weighted_cs = nn.CrossEntropyLoss(weight=weight, reduction="none")
        # unweighted cross-entropy, used only to recover pt
        self.cs = nn.CrossEntropyLoss(reduction="none")
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, predicted, target):
        """
        predicted: [batch_size, n_classes]
        target: [batch_size]
        """
        pt = torch.exp(-self.cs(predicted, target))
        #shape: [batch_size]
        entropy_loss = self.weighted_cs(predicted, target)
        #shape: [batch_size]
        focal_loss = ((1-pt)**self.gamma)*entropy_loss
        #shape: [batch_size]
        if self.reduction =="none":
            return focal_loss
        elif self.reduction == "mean":
            return focal_loss.mean()
        else:
            return focal_loss.sum()
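One way to sanity-check an implementation like this: with gamma=0 and no class weights, the (1-pt)^gamma factor becomes 1, so the result should match plain nn.CrossEntropyLoss. A minimal sketch, assuming the FocalLoss class above (the tensor values are made up):

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)           # [batch_size, n_classes]
targets = torch.randint(0, 3, (4,))  # [batch_size]

focal = FocalLoss(gamma=0)           # gamma=0 disables the focal modulation
ce = nn.CrossEntropyLoss()
print(focal(logits, targets).item(), ce(logits, targets).item())  # both values should match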

Sorry to disturb you. My dataset has severe class imbalance, and using class weights to balance it doesn't work. My task is emotion recognition (multi-task). The main architecture of the model is based on a GRU. Could you please give me some suggestions?
Thanks,
best wishes


Not sure when it happened, but it's available now in torchvision.ops.sigmoid_focal_loss.
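For reference, a minimal usage sketch of torchvision.ops.sigmoid_focal_loss: it expects raw logits and float 0/1 targets of the same shape (the shapes below are made up):

import torch
from torchvision.ops import sigmoid_focal_loss

logits = torch.randn(8, 1)                     # raw model outputs, no sigmoid applied
targets = torch.randint(0, 2, (8, 1)).float()  # binary labels, same shape as logits

loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean")
print(loss.item())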

This is my implementation of a multi-class focal loss using only the PyTorch loss function torch.nn.NLLLoss. In my code, L-h is the sequence-length dimension.

# C is the number of classes
# w is the alpha_t in the main paper (should sum up to 1)
# weight_focal is (1 - p_t)^gamma in the paper
# prediction is the raw output of the model (logits, before softmax)

loss_nll = nn.NLLLoss(weight=w, ignore_index=-1, reduction='none')  # w.shape = [C]
gamma = 2
softmax_pred = nn.Softmax(dim=-1)(prediction)         # [B, L-h, C]
logsoftmax_pred = nn.LogSoftmax(dim=-1)(prediction)   # [B, L-h, C]
weight_focal = torch.pow(1.0 - softmax_pred, gamma)   # [B, L-h, C]
# truth_hot = F.one_hot(truth, num_classes=n_classes + 1).to(torch.float32)
loss = loss_nll(weight_focal.transpose(1, 2) * logsoftmax_pred.transpose(1, 2), truth)  # [B, L-h]
# the transpose is necessary because NLLLoss expects the class dimension second
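To run the snippet end-to-end, the inputs would look like this (a sketch with made-up values for B, the sequence length L-h, and C):

import torch

B, seq_len, C = 2, 5, 4                    # made-up batch size, sequence length (L-h), number of classes
prediction = torch.randn(B, seq_len, C)    # raw logits [B, L-h, C]
truth = torch.randint(0, C, (B, seq_len))  # class indices [B, L-h]; entries of -1 are ignored
w = torch.full((C,), 1.0 / C)              # per-class alpha_t weights summing to 1

With these tensors, the loss computed above has shape [B, L-h] and can be reduced with .mean().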

Hi
I can't get the same results with your code and the other library mentioned here. I don't know why that is; could you please check?
Thanks

Hi Diego, I tried focal loss for one of my binary classification problems. The validation loss dropped from 0.3 to 0.05, but there's no change in accuracy. Can you check my implementation below to see if it's correct? Thank you.

class FocalLoss(nn.Module):
    """Binary focal loss on raw logits."""

    def __init__(self, alpha=0.25, gamma=1):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.eps = 1e-6  # note: defined but not used below

    def forward(self, inputs, targets):
        # per-element BCE on logits, then down-weight easy examples
        BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # probability assigned to the true class
        F_loss = self.alpha * (1 - pt)**self.gamma * BCE_loss
        return F_loss.mean()
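A quick check for this binary version (a minimal sketch, assuming the FocalLoss class just above): with alpha=1 and gamma=0 both the alpha and (1-pt)^gamma factors become 1, so it should reproduce nn.BCEWithLogitsLoss.

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8)                      # raw outputs, no sigmoid applied
targets = torch.randint(0, 2, (8,)).float()  # binary labels

focal = FocalLoss(alpha=1.0, gamma=0)
bce = nn.BCEWithLogitsLoss()
print(focal(logits, targets).item(), bce(logits, targets).item())  # both values should match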

loss_fn = FocalLoss(alpha=0.25, gamma=1)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10, verbose=True)

def train(data_loader, is_validation=False):
    if is_validation:
        model.eval()
    else:
        model.train()
    correct = 0
    total = 0
    loss_total = 0

    for batch in data_loader:
        optimizer.zero_grad()
        pred = model(batch.x.float(), batch.edge_index.long(), batch.batch.long())  
        target = batch.y.float()
        loss = loss_fn(pred, target)

        if not is_validation:
            loss.backward()
            optimizer.step()

        predicted = (pred >= 0.0).float()  # threshold raw logits at 0, i.e. probability 0.5
        total += target.size(0)
        correct += (predicted == target).sum().item()
        loss_total += loss.item()

    accuracy = correct/total
    return loss_total / len(data_loader), accuracy

def validate(data_loader):
    return train(data_loader, is_validation=True)