I am working on binary semantic segmentation and my dataset is highly imbalanced, i.e. there are very few foreground pixels. So I want to try focal loss, implemented as below, but the loss becomes zero after one or two epochs. Is my implementation correct, and if so, how do I troubleshoot this?
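For reference, the loss I am trying to implement is the focal loss from the RetinaNet paper (Lin et al.), FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the ground-truth class; I am using gamma = 2 and alpha = 0.75.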
import torch
import torch.nn as nn
from torch.autograd import Variable

def to_one_hot(tensor, nClasses):
    # (N, 1, H, W) integer labels -> (N, nClasses, H, W) one-hot encoding
    n, c, h, w = tensor.size()
    one_hot = torch.zeros(n, nClasses, h, w).cuda().scatter_(1, tensor.view(n, 1, h, w), 1)
    return one_hot

class FocalLoss(nn.Module):
    def __init__(self, classes, gamma=2, alpha=0.75, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.classes = classes
        self.alpha = alpha

    def forward(self, input_, target_):
        # one-hot encode the integer target so it matches the per-class predictions
        y = Variable(to_one_hot(target_.data, self.classes))
        # per-class probabilities, clamped away from 0/1 to keep log() finite
        logit = input_.sigmoid()
        logit = logit.clamp(self.eps, 1. - self.eps)
        loss = -1 * torch.log(logit) * y.float()               # cross entropy
        loss = self.alpha * loss * (1 - logit) ** self.gamma   # focal modulation
        loss = torch.mean(loss)
        return loss
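In case it helps, this is roughly how I construct and call the loss. It is only a minimal sanity check with dummy tensors: the shapes, batch size, and class count are placeholders for my actual data, and it assumes a CUDA device is available since to_one_hot() calls .cuda().

criterion = FocalLoss(classes=2)
logits = torch.randn(4, 2, 64, 64).cuda()            # raw network outputs, shape (N, C, H, W)
labels = torch.randint(0, 2, (4, 1, 64, 64)).cuda()  # integer ground-truth labels, shape (N, 1, H, W)
loss = criterion(logits, labels)
print(loss.item())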
Thanks in advance