Custom loss functions

Sure, here is the raw implementation written directly from the formula in the docs, as well as the stable internal formulation, which doesn't overflow for large values:

import torch
import torch.nn as nn
import torch.nn.functional as F


def my_bce_with_logits_loss(x, y):
    # Direct translation of the docs formula; log(1 - sigmoid(x)) underflows
    # to log(0) = -inf for large positive logits.
    loss = -1.0 * (y * F.logsigmoid(x) + (1 - y) * torch.log(1 - torch.sigmoid(x)))
    loss = loss.mean()
    return loss


def my_bce_with_logits_loss_stable(x, y):
    # Log-sum-exp trick: factor out max(-x, 0) so both exponents stay <= 0.
    max_val = (-x).clamp(min=0)
    loss = (1 - y) * x + max_val + torch.log(torch.exp(-max_val) + torch.exp(-x - max_val))
    loss = loss.mean()
    return loss
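
As a side note, the same stable formulation can be written a bit more compactly with F.softplus, since softplus(-x) is log(1 + exp(-x)) computed in a safe way. This variant is just a sketch for comparison, not what nn.BCEWithLogitsLoss calls internally:

def my_bce_with_logits_loss_softplus(x, y):
    # (1 - y) * x + log(1 + exp(-x)) is algebraically identical to the
    # max_val formulation above; softplus handles the large-|x| cases safely.
    loss = (1 - y) * x + F.softplus(-x)
    return loss.mean()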


criterion = nn.BCEWithLogitsLoss()

batch_size = 5
nb_classes = 1

# small values
x = torch.randn(batch_size, nb_classes, requires_grad=True)
y = torch.empty(batch_size, nb_classes).uniform_(0, 1)

loss_reference = criterion(x, y)
loss = my_bce_with_logits_loss(x, y)
loss_stable = my_bce_with_logits_loss_stable(x, y)

print(loss_reference)
> tensor(1.0072, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

print(loss_reference - loss)
> tensor(0., grad_fn=<SubBackward0>)

print(loss_reference - loss_stable)
> tensor(0., grad_fn=<SubBackward0>)


# large values
x = torch.randn(batch_size, nb_classes, requires_grad=True) * 100
y = torch.empty(batch_size, nb_classes).uniform_(0, 1)

loss_reference = criterion(x, y)
loss = my_bce_with_logits_loss(x, y)
loss_stable = my_bce_with_logits_loss_stable(x, y)

print(loss_reference)
> tensor(12.1431, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

print(loss_reference - loss)
> tensor(-inf, grad_fn=<SubBackward0>)

print(loss_reference - loss_stable)
> tensor(0., grad_fn=<SubBackward0>)
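
If you also want to verify the backward pass, a quick sanity check (just a sketch; the variable names below are only for illustration) is to compare the gradients w.r.t. the logits via torch.autograd.grad:

x = torch.randn(batch_size, nb_classes, requires_grad=True)
y = torch.empty(batch_size, nb_classes).uniform_(0, 1)

# Gradients of the reference loss and the stable custom loss w.r.t. x
grad_reference, = torch.autograd.grad(criterion(x, y), x)
grad_stable, = torch.autograd.grad(my_bce_with_logits_loss_stable(x, y), x)

# Should print True (up to floating point precision) if both formulations agree
print(torch.allclose(grad_reference, grad_stable))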