Since sigmoid
will be applied twice in this (wrong) approach, you might have scaled down the gradients and thus stabilized the training, e.g. if your learning rate was too high.
Here is a small example showing this effect:
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
)
data = torch.randn(1, 10)
target = torch.randint(0, 2, (1, 1)).float()
# 1) nn.BCEWithLogitsLoss
output = model(data)
loss = F.binary_cross_entropy_with_logits(output, target)
loss.backward()
print(model[0].weight.grad.norm())
> tensor(0.1741)
print(model[2].weight.grad.norm())
> tensor(0.2671)
# 2) nn.BCELoss
model.zero_grad()
output = model(data)
loss = F.binary_cross_entropy(torch.sigmoid(output), target)
loss.backward()
print(model[0].weight.grad.norm())
> tensor(0.1741)
print(model[2].weight.grad.norm())
> tensor(0.2671)
# 3) wrong
model.zero_grad()
output = model(data)
loss = F.binary_cross_entropy_with_logits(torch.sigmoid(output), target)
loss.backward()
print(model[0].weight.grad.norm())
> tensor(0.0595)
print(model[2].weight.grad.norm())
> tensor(0.0914)
Your loss might blow up and eventually become a NaN
value, e.g. if the learning rate is set too high, which would also fit my assumption.
While applying sigmoid
twice might have helped in your use case, I would recommend trying to debug the exploding loss (or NaN
values) instead.
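As a starting point for that debugging, a common pattern (not from your code, just a sketch under the same toy setup as above) is to check the loss for non-finite values before the backward pass and to clip the global gradient norm, so a too-high learning rate is less likely to diverge:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 1)
)
data = torch.randn(8, 10)
target = torch.randint(0, 2, (8, 1)).float()

output = model(data)
loss = F.binary_cross_entropy_with_logits(output, target)

# fail fast if the loss has already blown up to inf/NaN
if not torch.isfinite(loss):
    raise RuntimeError("non-finite loss encountered: {}".format(loss.item()))

loss.backward()

# clip the global gradient norm (max_norm=1.0 is an arbitrary example value);
# clip_grad_norm_ returns the total norm *before* clipping
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```

This way you would see immediately at which iteration the loss becomes non-finite, instead of masking the symptom by shrinking the gradients with a second sigmoid.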