Implemented Focal Loss for Semantic Segmentation (Question)

import torch
import torch.nn as nn

class FocalLoss(nn.modules.loss._WeightedLoss):

    def __init__(self, weight=None, gamma=2, reduction=None):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.axis = 1
        self.weight = weight  # weight will act as the alpha parameter to balance class weights

    def forward(self, input, target):
        input = input.transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()
        input = input.view(-1,input.shape[-1])

        ce_loss = nn.CrossEntropyLoss(ignore_index=255)
        loss = ce_loss(input, target.view(-1))
        pt = torch.exp(-loss)

        focal_loss = ((1 - pt) ** self.gamma * loss).mean()
        return focal_loss

Input shape: B, C(21), H, W. Target shape: B, C(1), H, W.
I've implemented this to try with a U-Net on the VOC-2012 semantic segmentation dataset.
Yes, it works, but somehow it doesn't feel right.

It seems like it's just taking the exponential of the negative cross-entropy loss and applying the focal loss equation afterwards. This doesn't make sense to me, because VOC segmentation images have more than one class per image, and if so, the per-class weighting of the loss isn't going to happen outside the cross-entropy. Then how is the equation
focal_loss = ((1 - pt) ** self.gamma * loss).mean() supposed to do its part? I am so confused, please help me :confused:
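To illustrate my confusion with a toy example (two "pixels", two classes, numbers are made up): with the default reduction='mean', the batch collapses to one scalar loss before exp(), so pt is a single number shared by every pixel instead of each pixel's own true-class probability.

```python
import torch
import torch.nn.functional as F

# Two "pixels" with raw logits over two classes; both have true class 0.
logits = torch.tensor([[4.0, 0.0],   # confident, nearly correct pixel
                       [0.5, 0.0]])  # uncertain pixel
target = torch.tensor([0, 0])

# What my code effectively does: one scalar loss, one scalar "pt" for all pixels.
mean_ce = F.cross_entropy(logits, target)          # scalar
pt_scalar = torch.exp(-mean_ce)                    # single number for the whole batch

# What I'd expect focal loss to need: the true-class probability per pixel.
per_ce = F.cross_entropy(logits, target, reduction='none')
pt_pixel = torch.exp(-per_ce)                      # one probability per pixel
```

Here pt_scalar lands between the two per-pixel probabilities (it's their geometric mean), so the easy pixel and the hard pixel end up with the same modulating factor.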

Edit: I think the pt variable is supposed to be the predicted probability, but the implementations I came across on the internet did it like this. How can I make the pt variable a probability distribution from the output of the cross-entropy function? Or should I use something like input = LogSoftmax(input, dim=1) at the beginning of forward?
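For example, one fix I was considering (the function name and details here are my own sketch, not from any reference implementation): keep the cross-entropy unreduced with reduction='none', so pt = exp(-ce) really is the predicted probability of the true class at each pixel, and only reduce to the mean after applying the (1 - pt)**gamma factor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def focal_loss_per_pixel(input, target, gamma=2, weight=None, ignore_index=255):
    # input: (B, C, H, W) raw logits; target: (B, H, W) class indices.
    # Unreduced cross-entropy gives one loss value per pixel.
    ce = F.cross_entropy(input, target, weight=weight,
                         ignore_index=ignore_index, reduction='none')  # (B, H, W)
    pt = torch.exp(-ce)  # per-pixel probability of the true class
    # Down-weight easy pixels (pt near 1) individually, then average.
    return ((1 - pt) ** gamma * ce).mean()
```

With gamma=0 this should reduce to plain mean cross-entropy, which seems like a good sanity check.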