Focal-loss-like classification loss

I have the following focal-loss-like implementation:

import torch


def focal_loss(pred: torch.Tensor, target: torch.Tensor, α: float = 2.0, β: float = 4.0) -> torch.Tensor:
  '''
  Treats the tensors as a contiguous array.
  Arguments:
    pred    (batch x c x h x w) in [0, 1]
    target  (batch x c x h x w) in [0, 1]
  '''
  # Positives are the locations where the target is exactly 1;
  # everything else is treated as a (down-weighted) negative.
  pos_mask = target.eq(1).float()
  neg_mask = target.lt(1).float()

  # Negatives near a positive peak are penalised less: the (1 - y)^β term from CornerNet.
  neg_weights = torch.pow(1 - target, β)

  pos_loss = torch.log(pred) * torch.pow(1 - pred, α) * pos_mask
  neg_loss = torch.log(1 - pred) * torch.pow(pred, α) * neg_weights * neg_mask

  num_pos  = pos_mask.sum()
  pos_loss = pos_loss.sum()
  neg_loss = neg_loss.sum()

  # Normalise by the number of positives; fall back to the raw negative loss
  # when the batch contains no positives at all.
  if num_pos == 0:
    total_loss = -neg_loss
  else:
    total_loss = -(pos_loss + neg_loss) / num_pos

  return total_loss

Whenever I use it, I get much worse results than with binary_cross_entropy.
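
For reference, this is roughly how I invoke both losses (a sketch with dummy data; the shapes, the sigmoid head, and the single-peak target are only placeholders for my real heatmaps):

import torch
import torch.nn.functional as F

pred   = torch.sigmoid(torch.randn(8, 1, 128, 128))    # sigmoid-activated heatmap output in (0, 1)
target = torch.zeros(8, 1, 128, 128)
target[:, :, 64, 64] = 1.0                              # dummy keypoint heatmap, peaks are exactly 1

loss_a = focal_loss(pred, target)                       # trains much worse for me
loss_b = F.binary_cross_entropy(pred, target)           # the baseline that works fine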

The concept is from CornerNet: Detecting Objects as Paired Keypoints.
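As far as I understand the paper (their Eq. 1), the loss is

L = -\frac{1}{N} \sum_{c,i,j}
\begin{cases}
(1 - p_{cij})^{\alpha} \log(p_{cij}) & \text{if } y_{cij} = 1 \\
(1 - y_{cij})^{\beta} \, p_{cij}^{\alpha} \, \log(1 - p_{cij}) & \text{otherwise}
\end{cases}

where p is the predicted heatmap, y the ground-truth heatmap (Gaussian-splatted, so the peaks are exactly 1), N the number of positive locations, and α = 2, β = 4, which is what the code above tries to implement.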

I suspect it is a numerical issue.
Is there any way to make the calculation more robust?
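
This is the failure mode I think I am hitting (a small sketch reusing focal_loss from above; the eps clamp at the end is just one idea I am considering, not something I have validated):

# Exact 0s or 1s in pred make torch.log return -inf, which turns the whole loss into nan.
pred   = torch.tensor([[[[0.0, 1.0], [0.5, 0.5]]]])     # 1 x 1 x 2 x 2
target = torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]])
print(focal_loss(pred, target))                          # tensor(nan)

# Possible mitigation: clamp the predictions away from 0 and 1 (eps chosen arbitrarily).
eps = 1e-6
print(focal_loss(pred.clamp(eps, 1 - eps), target))      # finite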

Is there an alternative loss for this kind of case?