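That may be caused by numerical instability: torch.log returns -inf when output is exactly 0 (or exactly 1, for the log(1 - output) term), and multiplying that -inf by a zero target gives nan. Clamping the predictions away from 0 and 1 may help. Note that 1 - 1e-8 rounds to exactly 1.0 in float32, so the upper bound only has an effect with an epsilon of roughly 1e-7 or larger. Also keep the leading minus sign so the result is a loss to minimize:
output = torch.clamp(output, min=1e-7, max=1 - 1e-7)
loss = -(pos_weight * target * torch.log(output) + neg_weight * (1 - target) * torch.log(1 - output))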
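For reference, here is a minimal self-contained sketch of the same idea wrapped in a function. The name weighted_bce and the eps default are hypothetical choices for illustration, and it assumes output already holds probabilities (e.g. after a sigmoid). If you still have the raw logits, the built-in F.binary_cross_entropy_with_logits with pos_weight computes the positive-weighted case in a numerically stable way without any clamping; the sketch compares against it.

```python
import torch
import torch.nn.functional as F

def weighted_bce(output, target, pos_weight=1.0, neg_weight=1.0, eps=1e-7):
    """Hypothetical helper: weighted binary cross-entropy on probabilities.

    Clamping keeps torch.log finite when a prediction saturates at 0 or 1.
    eps=1e-7 is used because 1 - 1e-8 rounds to 1.0 in float32.
    """
    output = torch.clamp(output, min=eps, max=1 - eps)
    # Negative log-likelihood, weighted separately for positive and negative targets
    loss = -(pos_weight * target * torch.log(output)
             + neg_weight * (1 - target) * torch.log(1 - output))
    return loss.mean()

# Saturated predictions stay finite thanks to the clamp
p = torch.tensor([0.0, 1.0, 0.3])
t = torch.tensor([0.0, 1.0, 1.0])
print(weighted_bce(p, t))

# With neg_weight=1 this should closely match the logits-based built-in
x = torch.tensor([-2.0, 0.5, 3.0])
print(weighted_bce(torch.sigmoid(x), t, pos_weight=2.0))
print(F.binary_cross_entropy_with_logits(x, t, pos_weight=torch.tensor(2.0)))
```

Working from logits is generally preferable when you control the model output, since the sigmoid and log are fused and no epsilon has to be tuned.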