Hi,
I am using multi-hot labels for a multi-label classification problem. Initially I used BCEWithLogitsLoss, but because the dataset is quite imbalanced, the model soon predicts all zeros. I have tried the focal loss below, but the model just does not converge. Any suggestions?
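import torch

def focal_loss(self, pred, gt):
    ''' Modified focal loss. Exactly the same as CornerNet.
    Runs faster and costs a little bit more memory
    '''
    # pred must be probabilities (sigmoid of the logits), not raw logits.
    # Clamp so torch.log() never sees exactly 0 or 1 (avoids -inf/NaN).
    pred = pred.clamp(min=1e-4, max=1 - 1e-4)
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()

    loss = 0
    # Positives: easy examples (pred close to 1) are down-weighted by (1 - pred)^2.
    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    # Negatives: easy examples (pred close to 0) are down-weighted by pred^2.
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_inds

    num_pos = pos_inds.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    # Normalize by the number of positive labels in the batch.
    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss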
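If the model's last layer still outputs raw logits (as it would for BCEWithLogitsLoss), passing them straight into the function above makes torch.log(pred) return NaN, which would look like non-convergence. A minimal sketch of an alternative, assuming the model keeps returning logits: the RetinaNet-style sigmoid focal loss, which gets the log-probabilities stably from F.binary_cross_entropy_with_logits. Note this is a different formulation from the CornerNet one (it adds an alpha class weight). The function name and the alpha=0.25, gamma=2.0 defaults are illustrative assumptions, not tuned values; torchvision also provides a similar torchvision.ops.sigmoid_focal_loss.

```python
import torch
import torch.nn.functional as F


def sigmoid_focal_loss_from_logits(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss (RetinaNet form) computed from raw logits.

    logits, targets: tensors of the same shape, e.g. (batch, num_classes);
    targets are multi-hot 0/1 floats. Hypothetical helper for illustration;
    the alpha/gamma defaults are the RetinaNet values, assumed here.
    """
    # Per-element BCE computed stably from logits (no explicit log(sigmoid)).
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t is the predicted probability of the true label for each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    # alpha weights positive labels, (1 - alpha) weights negative labels.
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t)^gamma down-weights easy, already well-classified elements.
    loss = alpha_t * (1 - p_t) ** gamma * ce
    # Normalize by the number of positive labels, as in the code above,
    # clamped to at least 1 so a batch with no positives does not divide by 0.
    num_pos = targets.sum().clamp(min=1.0)
    return loss.sum() / num_pos
```

Usage: `loss = sigmoid_focal_loss_from_logits(model(x), y)`, with `y` a float multi-hot tensor of the same shape as the logits. With gamma=0 and alpha=0.5 it reduces to half the summed BCE divided by the positive count, which is a handy sanity check; alpha and gamma are worth tuning for the specific imbalance.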