Focal loss gradient vanishing

I’ve been working on an imbalanced binary classification problem where the true-to-false ratio is 9:1 and the input is 20-dimensional tabular data. To handle this imbalance, I decided to use focal loss. My implementation of this loss is a PyTorch adaptation of the TensorFlow one.

I ran a couple of tests, and my focal loss implementation produces the same output as the TensorFlow one.

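import torch
import torch.nn as nn


class FocalLoss(nn.Module):
    """
    Weighs the contribution of each sample to the loss based on the classification error.
    :gamma: Focusing parameter. gamma=0 is equivalent to BCE loss.
    """
    def __init__(self, gamma, eps=1e-6):
        super().__init__()
        self.gamma = gamma
        self.eps = eps  # accepted but not used anywhere in forward()

    def forward(self, y_pred, y_true):
        # y_pred holds raw logits; y_true holds 0/1 labels.
        y_true = y_true.float()
        pred_prob = torch.sigmoid(y_pred)
        # Per-sample BCE; reduction="none" replaces the deprecated reduce=False.
        ce = nn.BCELoss(reduction="none")(pred_prob, y_true)

        # Probability the model assigns to the true class.
        p_t = (y_true * pred_prob) + ((1 - y_true) * (1 - pred_prob))
        modulator = 1.0
        if self.gamma:
            # A plain float exponent avoids relying on a global `device`.
            modulator = torch.pow(1.0 - p_t, self.gamma)
        return torch.mean(modulator * ce)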
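
One way to check a focal loss like the one above is to confirm that gamma=0 reproduces plain BCE and that gradients stay finite and nonzero for gamma>0. The snippet below is only a sketch under assumptions: `focal_loss_sketch` is a hypothetical functional rewrite (not the class from this post), it computes the cross-entropy term from logits, and the logits and labels are random made-up values.

```python
import torch
import torch.nn.functional as F


def focal_loss_sketch(logits, targets, gamma):
    """Binary focal loss on raw logits, averaged over the batch (illustrative)."""
    targets = targets.float()
    # Per-sample cross-entropy, computed directly from logits.
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    prob = torch.sigmoid(logits)
    # Probability the model assigns to the correct class.
    p_t = targets * prob + (1 - targets) * (1 - prob)
    return ((1 - p_t) ** gamma * ce).mean()


torch.manual_seed(0)
logits = torch.randn(64, requires_grad=True)   # made-up logits
targets = torch.randint(0, 2, (64,))           # made-up binary labels

# Check 1: gamma = 0 should reproduce plain binary cross-entropy.
bce = F.binary_cross_entropy_with_logits(logits, targets.float())
gamma0_matches_bce = torch.allclose(focal_loss_sketch(logits, targets, 0.0), bce)

# Check 2: gradients for gamma = 2 should be finite and not all zero.
loss = focal_loss_sketch(logits, targets, 2.0)
loss.backward()
grad_is_finite = bool(torch.isfinite(logits.grad).all())
grad_is_nonzero = bool(logits.grad.abs().sum() > 0)
print(gamma0_matches_bce, grad_is_finite, grad_is_nonzero)
```

Because the modulating factor is at most 1, the focal loss value is never larger than the BCE value on the same batch, so a much smaller loss is not by itself a sign that the model has improved.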

My model is a simple fully connected NN with n hidden layers (n=4 in this test) and ReLU activations. This architecture works reasonably well when cross-entropy is used as the loss function.
The figure below shows the gradient flow after one epoch. The validation loss approached 7e-12, and the validation accuracy is 50%.
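
For anyone who wants to reproduce a gradient-flow check numerically rather than from a plot, here is a rough sketch that prints the mean absolute gradient of each weight matrix after one backward pass. Everything in it is an assumption for illustration: the hidden width of 64, the random batch, and the use of plain BCE as a placeholder loss (swap in the focal loss to compare).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the network described above: 20 inputs, 4 hidden ReLU layers.
# The hidden width of 64 is an assumption.
layers, in_dim = [], 20
for _ in range(4):
    layers += [nn.Linear(in_dim, 64), nn.ReLU()]
    in_dim = 64
layers.append(nn.Linear(in_dim, 1))
model = nn.Sequential(*layers)

x = torch.randn(32, 20)                    # made-up batch
y = torch.randint(0, 2, (32,)).float()     # made-up labels

# Placeholder loss; replace with the loss under investigation.
loss = nn.functional.binary_cross_entropy_with_logits(model(x).squeeze(1), y)
loss.backward()

# Mean absolute gradient per weight matrix, from input layer to output layer.
grad_means = {
    name: p.grad.abs().mean().item()
    for name, p in model.named_parameters()
    if name.endswith("weight")
}
for name, g in grad_means.items():
    print(f"{name}: {g:.3e}")
```

Values that shrink by many orders of magnitude toward the input layers would point to vanishing gradients; comparing the printout for BCE and for focal loss on the same batch helps separate a loss problem from an architecture problem.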

I am using the Adam optimizer with lr=1e-4.
What do you think about my implementation of the focal loss? Is it legit?
I am also using a class-balanced sampler, so each batch contains an equal number of true and false examples.
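
The post does not show the sampler, so the following is only a sketch of one common way to get class-balanced batches in PyTorch, using `WeightedRandomSampler` with inverse-frequency weights. The dataset is made up (900 positives, 100 negatives, 20 features), and note that this approach balances batches in expectation rather than guaranteeing exactly equal counts.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Made-up data with the 9:1 ratio from the post: 900 positives, 100 negatives.
labels = torch.cat([torch.ones(900), torch.zeros(100)]).long()
features = torch.randn(1000, 20)

# Weight each sample by the inverse frequency of its class.
class_counts = torch.bincount(labels)                 # tensor([100, 900])
sample_weights = 1.0 / class_counts[labels].float()

sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
loader = DataLoader(TensorDataset(features, labels), batch_size=100, sampler=sampler)

# Fraction of positives seen over one pass; expected to be near 0.5.
seen = torch.cat([batch_labels for _, batch_labels in loader])
positive_fraction = seen.float().mean().item()
print(f"positive fraction: {positive_fraction:.2f}")
```

With batches already balanced this way, the class imbalance that focal loss is meant to address is largely removed at the data level, which is worth keeping in mind when comparing it against plain cross-entropy.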

So, is your implementation correct? How did you verify it?