I have been trying to replace `F.binary_cross_entropy` with my own custom binary cross-entropy loss, since I want to adapt it and make appropriate changes. Having it defined as a custom loss would let me experiment with it more thoroughly. That said, I double-checked whether my custom loss returns the same values as `F.binary_cross_entropy` by comparing both on 10 samples from the real data. As shown below, the results suggest that the computation is fine. However, at epoch 3 the custom loss degenerates to `nan` for both the discriminator and the generator. Before that, `F.binary_cross_entropy` and `bce_custom_loss` give similar values. I have been trying to tackle this instability for a couple of days now, without success.
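To locate where the `nan` first appears, PyTorch's anomaly detection can be switched on around the backward pass: it raises a `RuntimeError` naming the backward function that produced `nan`. Below is a minimal sketch that reuses `bce_custom_loss` from this post. The saturated prediction of exactly 1.0 is a hypothetical input, not taken from the real data. Anomaly detection slows training down noticeably, so it is best enabled only while debugging.

```python
import torch


def bce_custom_loss(x, y):
    # Custom BCE from this post: each log output is clamped at -100
    return torch.mean(-y * torch.clamp(torch.log(x), min=-100)
                      - (1.0 - y) * torch.clamp(torch.log(1.0 - x), min=-100))


# Hypothetical inputs: the first prediction has saturated to exactly 1.0
pred = torch.tensor([1.0, 0.3], requires_grad=True)
target = torch.tensor([1.0, 0.0])

caught = None
with torch.autograd.set_detect_anomaly(True):
    loss = bce_custom_loss(pred, target)
    try:
        loss.backward()
    except RuntimeError as err:
        # The message names the backward function that first returned nan
        caught = err

print(caught)
```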
I suspect this instability might be due to the fact that I do not apply the "log-sum-exp trick", which PyTorch seems to apply in the background, as indicated in the following quote:

> Elaborating on the above, `sigmoid()` is not there, because it is not explicitly part of BCE. It is hiding in the `log(sigmoid())` version of the "log-sum-exp trick," in this line from the C++ code that Pytorchtester posted:

```cpp
loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
```
However, I am not sure whether this is the reason for the instability. As I am not familiar with the log-sum-exp trick, I am also not sure how I could include it.
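For reference, the quoted C++ line can be transcribed to Python roughly as follows. This is a sketch that assumes the network returns raw logits (no final `sigmoid()`), which is what `F.binary_cross_entropy_with_logits` expects; `bce_with_logits_custom` is a hypothetical name. It relies on the identity log(1 + exp(-z)) = m + log(exp(-m) + exp(-z - m)) with m = max(-z, 0), which keeps every `exp()` argument at or below zero so nothing overflows.

```python
import torch
import torch.nn.functional as F


def bce_with_logits_custom(logits, target):
    # Hypothetical helper: BCE computed from raw logits (no sigmoid applied).
    # Per element: (1 - t) * z + log(1 + exp(-z)), evaluated stably.
    max_val = (-logits).clamp(min=0)  # m = max(-z, 0)
    loss = (1 - target) * logits + max_val + torch.log(
        torch.exp(-max_val) + torch.exp(-logits - max_val))
    return loss.mean()


# Made-up logits, including extreme values that would saturate a sigmoid
logits = torch.tensor([-200.0, -5.0, 0.0, 5.0, 200.0], requires_grad=True)
target = torch.tensor([0.0, 1.0, 1.0, 0.0, 1.0])

custom = bce_with_logits_custom(logits, target)
reference = F.binary_cross_entropy_with_logits(logits.detach(), target)
custom.backward()

print(custom.item(), reference.item())
print(logits.grad)  # finite for every element
```

With logits of ±200 the result stays finite, whereas `torch.sigmoid` would round those inputs to exactly 0.0 or 1.0 in float32 before any log is taken.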
My second suspect would be the forward pass. However, I make no changes to the forward pass when switching between `F.binary_cross_entropy` and my custom loss. Since I am only replacing `F.binary_cross_entropy` with a custom loss that should ideally be identical, I figured this could not be the reason. So I assume the forward pass is fine, since it works with `F.binary_cross_entropy`. Here again, however, I am not certain.
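Identical forward values do not guarantee identical gradients. Below is a minimal check, assuming a prediction has saturated to exactly 1.0 (which `torch.sigmoid` does in float32 for large enough inputs); the two-element tensors are made-up inputs, not the real data. Both losses return the same finite value, but backpropagating through `torch.clamp(torch.log(1.0 - x), min=-100)` gives `nan`: the clamp zeroes the incoming gradient, and the backward of `torch.log` then divides that zero by `1.0 - x = 0`.

```python
import torch
import torch.nn.functional as F


def bce_custom_loss(x, y):
    # Custom BCE from this post: each log output is clamped at -100
    return torch.mean(-y * torch.clamp(torch.log(x), min=-100)
                      - (1.0 - y) * torch.clamp(torch.log(1.0 - x), min=-100))


target = torch.tensor([1.0, 0.0])
# Assumed saturated prediction (exactly 1.0) for the positive sample
pred_custom = torch.tensor([1.0, 0.3], requires_grad=True)
pred_builtin = torch.tensor([1.0, 0.3], requires_grad=True)

loss_custom = bce_custom_loss(pred_custom, target)
loss_builtin = F.binary_cross_entropy(pred_builtin, target)
loss_custom.backward()
loss_builtin.backward()

print(loss_custom.item(), loss_builtin.item())  # same finite value
print(pred_custom.grad)   # first entry is nan
print(pred_builtin.grad)  # all entries finite
```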
What do you think? Please let me know anything that comes to mind, as I am quite stuck. I would be happy to experiment with your suggestions and get this solved.
Note that I clamp the output of the log function at -100, as indicated in the PyTorch source code:

> Our solution is that BCELoss clamps its log function outputs to be greater than or equal to -100. This way, we can always have a finite loss value and a linear backward method.
```python
import torch
import torch.nn.functional as F


def bce_custom_loss(x, y):
    # x: predicted probabilities, y: targets.
    # Each log output is clamped at -100, mirroring BCELoss.
    loss = torch.mean(-y * torch.clamp(torch.log(x), min=-100)
                      - (1.0 - y) * torch.clamp(torch.log(1.0 - x), min=-100))
    return loss


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

real_pred = torch.tensor([0.8114, 0.2695, 0.8771, 0.4360, 0.7581,
                          0.8822, 0.6321, 0.5566, 0.6454, 0.4139]).to(device)
fake_pred = torch.tensor([0.0541, 0.5377, 0.8052, 0.8737, 0.3856,
                          0.2496, 0.2741, 0.3890, 0.1179, 0.7338]).to(device)
label_real = torch.ones(10, device=device)
label_fake = torch.zeros(10, device=device)

print("------Generator Loss")
print("F.binary_cross_entropy:", F.binary_cross_entropy(fake_pred, label_real))
print("bce_custom_loss:", bce_custom_loss(fake_pred, label_real))

print("------Discriminator Loss")
# Both functions take (input=predictions, target=labels), in that order
print("F.binary_cross_entropy:", 0.5 * (F.binary_cross_entropy(real_pred, label_real)
                                       + F.binary_cross_entropy(fake_pred, label_fake)))
print("bce_custom_loss:", 0.5 * (bce_custom_loss(real_pred, label_real)
                                 + bce_custom_loss(fake_pred, label_fake)))
```
```
------Generator Loss
F.binary_cross_entropy: tensor(1.0916)
bce_custom_loss: tensor(1.0916)
------Discriminator Loss
F.binary_cross_entropy: tensor(0.5248)
F.binary_cross_entropy: tensor(55.7930)
bce_custom_loss: tensor(0.5248)
bce_custom_loss tensor(55.7930)
```
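Clamping the log output at -100 protects only the forward value; autograd still computes the gradient of `torch.log` as 1/x (or 1/(1 - x)). One way to keep the backward pass finite while staying in probability space is a custom `torch.autograd.Function` with a hand-written gradient. The sketch below (`ClampedBCE` and `bce_safe_loss` are hypothetical names) uses the analytic derivative (x - y) / (x(1 - x)) of the per-element loss, with the denominator floored at an assumed `EPS = 1e-12`. It assumes mean reduction and that the targets do not need gradients.

```python
import torch
import torch.nn.functional as F


class ClampedBCE(torch.autograd.Function):
    """Hypothetical BCE (mean reduction) with a hand-written backward."""

    EPS = 1e-12  # assumed floor for the gradient denominator

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        # Forward as in the post: log outputs clamped at -100
        log_x = torch.clamp(torch.log(x), min=-100)
        log_1mx = torch.clamp(torch.log(1.0 - x), min=-100)
        return torch.mean(-y * log_x - (1.0 - y) * log_1mx)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # d/dx of the per-element loss is (x - y) / (x * (1 - x));
        # flooring the denominator keeps it finite at x == 0 or x == 1
        denom = torch.clamp(x * (1.0 - x), min=ClampedBCE.EPS)
        grad_x = grad_output * (x - y) / denom / x.numel()
        return grad_x, None  # no gradient for the targets


def bce_safe_loss(x, y):
    return ClampedBCE.apply(x, y)


# Made-up inputs; the first prediction is saturated at exactly 1.0
pred = torch.tensor([1.0, 0.3, 0.8], requires_grad=True)
target = torch.tensor([1.0, 0.0, 1.0])
loss = bce_safe_loss(pred, target)
loss.backward()

# Compare against the built-in loss on an independent copy of the inputs
ref_pred = pred.detach().clone().requires_grad_()
ref_loss = F.binary_cross_entropy(ref_pred, target)
ref_loss.backward()

print(loss.item(), ref_loss.item())
print(pred.grad, ref_pred.grad)
```

For these inputs the sketch matches `F.binary_cross_entropy` in both value and gradient, and the gradient stays finite at the saturated prediction instead of turning into `nan`.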