I want to create a custom loss function for multi-label classification. The idea is to weigh the positive and negative labels differently. For this, I am making use of this custom code implementation.
class WeightedBCEWithLogitLoss(nn.Module):
def __init__(self, pos_weight, neg_weight):
super(WeightedBCEWithLogitLoss, self).__init__()
self.register_buffer('neg_weight', neg_weight)
self.register_buffer('pos_weight', pos_weight)
def forward(self, input, target):
assert input.shape == target.shape, "The loss function received invalid input shapes"
y_hat = torch.sigmoid(input + 1e-8)
loss = -1.0 * (self.pos_weight * target * torch.log(y_hat + 1e-6) + self.neg_weight * (1 - target) * torch.log(1 - y_hat + 1e-6))
# Account for 0 times inf which leads to nan
loss[torch.isnan(loss)] = 0
# We average across each of the extra attribute dimensions to generalize it
loss = loss.mean(dim=1)
# We use mean reduction for our task
return loss.mean()
I started getting nan
values which I realized happened because of 0 times inf multiplication. I handled it as shown in the figure. Next, I again saw getting inf
as the error value and corrected it by adding 1e-6 to the log (I tried with 1e-8 but that still gave me inf error value).
It would be great if someone can take a look and suggest further improvements and rectify any more bugs visible here.