I would like to confirm that the implementation below of a multi-label focal loss, which also accepts a per-class `weight` parameter to handle class imbalance, is correct. (@ptrblck, I would appreciate your feedback if possible.)
class MultiLabelFocalLoss(torch.nn.Module):
    """Multi-label focal loss with optional per-class weights.

    Applies the focal modulating factor ``(1 - p_t) ** gamma`` (Lin et al.,
    2017) to the element-wise binary cross-entropy of every label, so that
    well-classified labels contribute less to the total loss.

    Args:
        weight: optional per-class weight (tensor or sequence of numbers)
            used to counter class imbalance. It is multiplied element-wise
            with the per-label focal term, so it must broadcast against
            ``outputs`` (e.g. shape ``(num_classes,)`` for ``outputs`` of
            shape ``(batch, num_classes)``). ``None`` weights every class
            equally.
        gamma: focusing parameter; ``gamma=0`` reduces the loss to
            (weighted) binary cross-entropy.

    Note:
        ``outputs`` must already be probabilities in ``[0, 1]`` (e.g. after
        ``torch.sigmoid``) because the loss is built on ``torch.nn.BCELoss``.
    """

    def __init__(self, weight=None, gamma=2):
        super().__init__()
        self.gamma = gamma
        if weight is not None:
            weight = torch.as_tensor(weight)
        # Registered as a buffer so it follows the module across
        # ``.to(device)`` / ``.cuda()``; non-persistent so the state dict
        # keeps the same keys as before.
        self.register_buffer("weight", weight, persistent=False)
        self.loss = torch.nn.BCELoss(reduction="none")

    def forward(self, outputs, targets):
        """Return the focal loss averaged over every element.

        Args:
            outputs: predicted probabilities, same shape as ``targets``.
            targets: binary (or soft) labels in ``[0, 1]``.

        Returns:
            Scalar tensor: mean of the (optionally class-weighted) focal
            term over all batch x class elements.
        """
        bce = self.loss(outputs, targets)
        # p_t must be derived from the *unweighted* BCE: exp(-BCE) is the
        # probability the model assigns to the true label. Using the
        # class-weighted loss here would distort the modulating factor.
        pt = torch.exp(-bce)
        focal = (1 - pt) ** self.gamma * bce
        if self.weight is not None:
            focal = focal * self.weight
        # .mean() averages over all elements (batch x classes), not only
        # over the batch dimension.
        return focal.mean()