Hi Everyone,

I have been trying to replace `F.binary_cross_entropy` with my own custom binary cross entropy loss, since I want to adapt it and make appropriate changes. Having it defined as a custom loss would allow me to experiment with it more thoroughly. That said, I double-checked whether my custom loss returns similar values to `F.binary_cross_entropy` by comparing both on 10 samples from the real data. As shown below, the results suggest that the computation is fine; however, after 3 epochs the loss of the custom loss function degenerates to `nan` for both the discriminator and the generator. Before that, `F.binary_cross_entropy` and `bce_custom_loss` return similar values. I have been trying to tackle this instability for a couple of days now, without success.

I could imagine that this instability is due to the fact that I do not apply the “log-sum-exp trick”, which PyTorch seems to apply in the background, as indicated here.

Elaborating on the above, `sigmoid()` is not there because it is not explicitly part of `BCE`. It is hiding in the `log(sigmoid())` version of the “log-sum-exp trick,” in this line from the C++ code that Pytorchtester posted:

```
loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
```
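For my own understanding, I tried to rewrite that C++ line in Python and check it against `F.binary_cross_entropy_with_logits`. This is just my reading of the line (I am assuming from the surrounding source that `max_val` is `(-input).clamp(min=0)`), so it may be off:

```python
import torch
import torch.nn.functional as F

def stable_bce_with_logits(input, target):
    # max_val = max(-input, 0), taken from the surrounding C++ source
    max_val = (-input).clamp(min=0)
    # (1 - target) * input + max_val + log(exp(-max_val) + exp(-input - max_val))
    loss = ((1 - target) * input + max_val
            + ((-max_val).exp() + (-input - max_val).exp()).log())
    return loss.mean()

logits = torch.tensor([-3.0, -0.5, 0.0, 2.0, 10.0])
targets = torch.tensor([0.0, 1.0, 1.0, 0.0, 1.0])
print(stable_bce_with_logits(logits, targets))
print(F.binary_cross_entropy_with_logits(logits, targets))  # should match
```

At least on these values the two agree, so I believe this is the stabilized form of `(1 - t) * x + log(1 + exp(-x))`.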

However, I am not sure whether this is the reason for the instability. As I am not familiar with the log-sum-exp trick, I am also not sure how I could include it.
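If the trick is indeed what is missing, the only way I could think of including it would be to let the networks output raw logits and move the sigmoid into the loss via `logsigmoid`, which applies the trick internally. Just a sketch; `bce_custom_loss_from_logits` is a name I made up, and it assumes the final `nn.Sigmoid()` would be removed from the models:

```python
import torch
import torch.nn.functional as F

def bce_custom_loss_from_logits(logits, y):
    # log(sigmoid(x)) and log(1 - sigmoid(x)) computed stably:
    # logsigmoid never returns -inf, even for extreme logits
    log_p = F.logsigmoid(logits)
    log_1mp = F.logsigmoid(-logits)  # equals log(1 - sigmoid(logits))
    return torch.mean(-y * log_p - (1.0 - y) * log_1mp)

logits = torch.tensor([-50.0, -1.0, 0.0, 1.0, 50.0])  # extreme values stay finite
y = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0])
print(bce_custom_loss_from_logits(logits, y))
```

On these values it matches `F.binary_cross_entropy_with_logits`, but I have not tried it inside the GAN training loop.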

My second suspect would be the forward pass. However, I make no changes to the forward passes when switching between `F.binary_cross_entropy` and my custom loss. Since it is only a question of replacing `F.binary_cross_entropy` with my custom loss, which ideally should be identical, I figured this could not be the reason. As such I assume that the forward pass is fine, since it works with `F.binary_cross_entropy`. Here, again, I am not certain.
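One further check that comes to mind is comparing the gradients of the two losses, not just their values, on random predictions (a quick sketch with made-up data):

```python
import torch
import torch.nn.functional as F

def bce_custom_loss(x, y):
    return torch.mean(-y * torch.clamp(torch.log(x), min=-100)
                      - (1.0 - y) * torch.clamp(torch.log(1.0 - x), min=-100))

torch.manual_seed(0)
pred = torch.rand(10, requires_grad=True)        # predictions away from 0 and 1
target = torch.randint(0, 2, (10,)).float()      # random 0/1 labels

g_ref, = torch.autograd.grad(F.binary_cross_entropy(pred, target), pred)
g_custom, = torch.autograd.grad(bce_custom_loss(pred, target), pred)
print(torch.allclose(g_ref, g_custom, rtol=1e-4, atol=1e-6))
```

For predictions away from 0 and 1 this should print `True`; the interesting question is what happens once predictions saturate.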

What do you think? Please let me know about anything that comes to mind, as I am quite stuck currently. I would be happy to experiment with your suggestions and have this matter solved.

Kind regards,

Deniz

Note that I clamp the output of the log function at -100, as indicated in the PyTorch source code:

> Our solution is that BCELoss clamps its log function outputs to be greater than or equal to -100. This way, we can always have a finite loss value and a linear backward method.
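As a sanity check of that clamping, I tried a toy example with a degenerate prediction (my own example, not from the PyTorch docs). The forward value stays finite, though I am not sure my `torch.clamp` reproduces the "linear backward" that PyTorch's built-in loss has:

```python
import torch

x = torch.tensor([0.0, 0.5, 1.0], requires_grad=True)
y = torch.tensor([1.0, 1.0, 1.0])

# forward stays finite thanks to the clamp: log(0) = -inf becomes -100
loss = torch.mean(-y * torch.clamp(torch.log(x), min=-100)
                  - (1.0 - y) * torch.clamp(torch.log(1.0 - x), min=-100))
print(loss)  # finite (about 33.56 here)
loss.backward()
print(x.grad)  # may contain nan where x is exactly 0 or 1
```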

Testing:

```
import torch
import torch.nn.functional as F

def bce_custom_loss(x, y):
    # clamp the log outputs at -100, as in the PyTorch source
    loss = torch.mean(-y * torch.clamp(torch.log(x), min=-100)
                      - (1.0 - y) * torch.clamp(torch.log(1.0 - x), min=-100))
    return loss

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
real_pred = torch.tensor([0.8114, 0.2695, 0.8771, 0.4360, 0.7581, 0.8822, 0.6321, 0.5566, 0.6454,
                          0.4139]).to(device)
fake_pred = torch.tensor([0.0541, 0.5377, 0.8052, 0.8737, 0.3856, 0.2496, 0.2741, 0.3890, 0.1179,
                          0.7338]).to(device)
label_real = torch.tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]).to(device)
label_fake = torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]).to(device)
print("------Generator Loss")
print("F.binary_cross_entropy:", F.binary_cross_entropy(fake_pred, label_real))
print("bce_custom_loss:", bce_custom_loss(fake_pred, label_real))
print("------Discriminator Loss")
# note: F.binary_cross_entropy takes (input, target), so predictions come first
print("F.binary_cross_entropy:", 0.5 * (F.binary_cross_entropy(real_pred, label_real)
                                        + F.binary_cross_entropy(fake_pred, label_fake)))
print("bce_custom_loss:", 0.5 * (bce_custom_loss(real_pred, label_real)
                                 + bce_custom_loss(fake_pred, label_fake)))
```

Output:

```
------Generator Loss
F.binary_cross_entropy: tensor(1.0916)
bce_custom_loss: tensor(1.0916)
------Discriminator Loss
F.binary_cross_entropy: tensor(0.6408)
bce_custom_loss: tensor(0.6408)
```