I have the following focal-loss-like implementation:
import torch

def focal_loss(pred: torch.Tensor, target: torch.Tensor, α: float = 2.0, β: float = 4.0) -> torch.Tensor:
    '''
    Treats the tensors as a contiguous array.
    Arguments:
        pred   (batch x c x h x w), values in [0, 1]
        target (batch x c x h x w), values in [0, 1]
    '''
    pos_mask = target.eq(1).float()          # 1 where the ground truth is exactly 1
    neg_mask = target.lt(1).float()          # 1 everywhere else
    neg_weights = torch.pow(1 - target, β)   # down-weight negatives close to a positive

    pos_loss = torch.log(pred) * torch.pow(1 - pred, α) * pos_mask
    neg_loss = torch.log(1 - pred) * torch.pow(pred, α) * neg_weights * neg_mask

    num_pos = pos_mask.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        # no positive locations in this batch: only the negative term applies
        total_loss = -neg_loss
    else:
        # normalize by the number of positive locations
        total_loss = -(pos_loss + neg_loss) / num_pos
    return total_loss
Whenever I use it, I get much worse results than with binary_cross_entropy.
The concept is from CornerNet: Detecting Objects as Paired Keypoints.
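For reference, the penalty-reduced focal loss there (Eq. 1 in the paper, as I read it) is

L_{det} = \frac{-1}{N} \sum_{c=1}^{C}\sum_{i=1}^{H}\sum_{j=1}^{W}
\begin{cases}
\left(1 - p_{cij}\right)^{\alpha} \log\left(p_{cij}\right) & \text{if } y_{cij} = 1 \\
\left(1 - y_{cij}\right)^{\beta} \left(p_{cij}\right)^{\alpha} \log\left(1 - p_{cij}\right) & \text{otherwise}
\end{cases}

where N is the number of locations with y_{cij} = 1, which is what num_pos is meant to be in my code.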
I suspect it is a numerical thing.
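For example, if pred comes from a sigmoid over raw logits and the logits saturate, pred contains exact 1.0 values in float32 and the log blows up. A minimal sketch (the shapes and values here are just an illustration, not my real data):

import torch

logits = torch.tensor([[[[20.0, -20.0]]]])  # hypothetical saturated outputs, shape (1, 1, 1, 2)
pred = torch.sigmoid(logits)                # sigmoid(20) rounds to exactly 1.0 in float32
target = torch.zeros_like(pred)             # no positive locations in this example

print(torch.log(1 - pred))                  # -inf where pred == 1.0
print(focal_loss(pred, target))             # inf: the -inf propagates through neg_loss

The positive branch has the same problem when pred is exactly 0 at a location where target is 1, since log(0) * (1 - 0)^α = -inf.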
Is there any way to make the calculation more robust?
Is there an alternative loss that fits this case?