Hi,
I am trying to implement a focal loss. Below is my approach: I have just modified the cross-entropy loss.
Is it correct?
import torch.nn.functional as F
import torch.nn as nn
class FocalLoss(nn.CrossEntropyLoss):
    """Multi-class focal loss built on top of cross-entropy.

    For every element whose true-class probability is ``p_t`` the loss is::

        FL = alpha * (1 - p_t) ** gamma * (-log(p_t))

    With ``gamma=0`` and ``alpha=1`` this reproduces ``nn.CrossEntropyLoss``
    exactly, including ``weight``, ``ignore_index`` and every ``reduction``
    mode. A larger ``gamma`` down-weights well-classified examples.

    Subclassing ``nn.CrossEntropyLoss`` reuses its constructor. The constructor
    stores ``weight`` on the module and maps the deprecated ``size_average`` /
    ``reduce`` arguments onto ``reduction``. Only ``forward`` is replaced.

    Args:
        weight: optional per-class rescaling tensor of shape ``(C,)``, applied
            as in ``nn.CrossEntropyLoss``.
        size_average: deprecated; see ``nn.CrossEntropyLoss``.
        gamma: focusing parameter (>= 0); 0 disables focusing.
        alpha: scalar factor multiplying the whole loss. Defaults to 1;
            ``alpha=0`` would make the loss identically zero.
        ignore_index: target value that contributes no loss and is excluded
            from the ``'mean'`` denominator.
        reduce: deprecated; see ``nn.CrossEntropyLoss``.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``.

    Shape:
        logits: ``(N, C)`` or ``(N, C, d1, ...)`` raw, unnormalized scores.
        labels: ``(N,)`` or ``(N, d1, ...)`` integer class indices.
            Class-probability targets are not supported.
    """

    def __init__(self, weight=None, size_average=None, gamma=0, alpha=1,
                 ignore_index=-100, reduce=None, reduction='mean'):
        super().__init__(weight=weight, size_average=size_average,
                         ignore_index=ignore_index, reduce=reduce,
                         reduction=reduction)
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, labels):
        """Return the focal loss of ``logits`` against class-index ``labels``."""
        # Unweighted, unreduced CE is exactly -log(p_t) per element, which is
        # what lets us recover p_t. Class weights are applied afterwards so
        # they do not distort p_t. Ignored positions come back as 0.
        ce = F.cross_entropy(logits, labels, reduction='none',
                             ignore_index=self.ignore_index)
        pt = (-ce).exp()
        # The clamp guards against tiny negative values from rounding, which
        # would turn into NaN for a fractional gamma.
        focal = self.alpha * (1.0 - pt).clamp(min=0) ** self.gamma * ce

        valid = labels != self.ignore_index
        if self.weight is None:
            elem_weight = valid.to(focal.dtype)
        else:
            # Ignored labels may be out of range (e.g. -100), so index with a
            # safe placeholder class and zero those entries via the mask.
            safe_labels = labels.masked_fill(~valid, 0)
            elem_weight = self.weight[safe_labels] * valid
        focal = focal * elem_weight

        if self.reduction == 'none':
            return focal
        if self.reduction == 'sum':
            return focal.sum()
        # 'mean' divides by the total weight of the non-ignored elements,
        # matching F.cross_entropy's weighted mean.
        return focal.sum() / elem_weight.sum()