Focal Loss Implementation for Binary Semantic Segmentation?


(Anil Batra) #1

I am working on binary semantic segmentation, and my dataset is highly imbalanced: foreground pixels are far fewer than background pixels. I want to try focal loss, implemented as shown below, but the loss drops to zero after 1-2 epochs. Is my implementation correct? If so, how do I troubleshoot this?

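import torch
import torch.nn as nn
from torch.autograd import Variable

def to_one_hot(tensor, nClasses):
    # tensor: integer (int64) class labels of shape (n, 1, h, w)
    n, c, h, w = tensor.size()
    one_hot = torch.zeros(n, nClasses, h, w).cuda().scatter_(1, tensor.view(n, 1, h, w), 1)
    return one_hot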
    
class FocalLoss(nn.Module):

    def __init__(self,classes, gamma=2,alpha=0.75, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.classes = classes
        self.alpha = alpha

    def forward(self, input_, target_):
        
        y = Variable(to_one_hot(target_.data, self.classes))

        logit = input_.sigmoid()
        logit = logit.clamp(self.eps, 1. - self.eps)
        loss = -1 * torch.log(logit) * y.float() # cross entropy
        loss = self.alpha * loss * (1 - logit) ** self.gamma # focal loss
        loss = torch.mean(loss)
        
        return loss

Thanks in advance :slight_smile:
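
One thing that may be worth checking in the code above: the loss keeps only the `-log(p) * y` term, and `p` comes from a per-channel sigmoid rather than a softmax. Nothing then penalises a channel for predicting 1 where its target is 0, so the network can shrink the loss by pushing every logit up. The snippet below is a small CPU-only sketch of that effect, not the poster's training setup. The helper `positive_only_focal` and the tensor shapes are made up for illustration; the arithmetic mirrors the `forward` above.

```python
import torch

def positive_only_focal(logits, one_hot, gamma=2.0, alpha=0.75, eps=1e-7):
    # Same steps as the posted forward(): sigmoid, clamp, positive term only.
    p = logits.sigmoid().clamp(eps, 1.0 - eps)
    loss = -torch.log(p) * one_hot            # no (1 - y) * log(1 - p) term
    loss = alpha * loss * (1 - p) ** gamma
    return loss.mean()

torch.manual_seed(0)

# Hypothetical batch: sparse foreground labels, shape (n, 1, h, w).
labels = (torch.rand(2, 1, 8, 8) > 0.9).long()
one_hot = torch.zeros(2, 2, 8, 8).scatter_(1, labels, 1.0)

random_logits = torch.randn(2, 2, 8, 8)
# Both channels answer "yes" at every pixel, which is a useless prediction.
saturated_logits = torch.full((2, 2, 8, 8), 20.0)

loss_random = positive_only_focal(random_logits, one_hot)
loss_saturated = positive_only_focal(saturated_logits, one_hot)
print(loss_random.item(), loss_saturated.item())
```

If the saturated case prints a value that is effectively zero, a training loss that collapses to zero would be consistent with the model learning large logits everywhere rather than learning the segmentation.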


(Wu Xiaodong) #2

I think this repo may help you.


(Xiaokang Wang) #3

Hi anil_baatra,

Thanks for posting your implementation. Your implementation looks good. Have you figured out why it becomes zero?

Thanks,

Xiaokang


(Anil Batra) #4

Thanks @Xiaokang_Wang for your response. I was not able to figure out the issue with the focal loss, but I was able to get decent results with a few architecture modifications. Do you have any suggestions for troubleshooting it?

Thanks
Anil
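
For anyone troubleshooting the same thing, one option is to compare against a focal loss that keeps both the foreground and background terms. Below is a minimal sketch of the alpha-balanced form, -alpha_t * (1 - p_t)^gamma * log(p_t), built on `F.binary_cross_entropy_with_logits`. It is not the original poster's code. It assumes the model outputs a single-channel logit map of shape (n, 1, h, w), that the target is a 0/1 mask of the same shape, and that `alpha` weights the foreground class. The class name `BinaryFocalLoss` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BinaryFocalLoss(nn.Module):
    """Sketch of an alpha-balanced focal loss for single-channel logits."""

    def __init__(self, gamma=2.0, alpha=0.75):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # assumed here to be the foreground weight

    def forward(self, logits, target):
        target = target.float()
        # Per-pixel binary cross entropy, i.e. -log(p_t), computed stably.
        bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
        p = torch.sigmoid(logits)
        # p_t is the probability the model assigns to the true class.
        p_t = p * target + (1 - p) * (1 - target)
        alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

criterion = BinaryFocalLoss()

# Hypothetical mask with a small foreground patch.
target = torch.zeros(2, 1, 8, 8)
target[:, :, :2, :2] = 1

saturated = torch.full((2, 1, 8, 8), 20.0)   # predicts foreground everywhere
correct = (target * 2 - 1) * 20.0            # confident and right everywhere

loss_saturated = criterion(saturated, target)
loss_correct = criterion(correct, target)
print(loss_saturated.item(), loss_correct.item())
```

With both terms present, predicting foreground everywhere should give a clearly non-zero loss, and only predictions that match the mask should approach zero. If the loss still collapses with a formulation like this, the target encoding and the logit shapes are probably the next things to inspect.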