Hello!
I’ve trained a stand-alone VAE based on the PyTorch example and a few other bits of code found on GitHub. It works well and my output images look quite good. All of the examples dealt with MNIST, but my model uses ImageNet images, so it’s a bit bigger than the examples. Apart from that, it doesn’t differ much.
When I then use the VAE in a bigger network, where the output of the VAE is fed to a CNN and their losses are summed (there is a reason for that, I promise), the KLD loss of the VAE part is NaN. The VAE and CNN form one model, but I return the output of the VAE (assigned to a separate variable before it even enters the CNN half) and calculate the VAE loss with the same function I used in the stand-alone VAE.
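For reference, here is a minimal sketch of the kind of combined model described above. Everything in it is hypothetical (the class names `TinyVAE` and `VAEThenCNN`, the fully connected encoder/decoder, the 32x32 input size, 16 latent dimensions, 10 classes); it only mirrors the four outputs `output, ae_output, mean, logvar` used in the training code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyVAE(nn.Module):
    """Hypothetical fully connected VAE for 3 x size x size images in [0, 1]."""

    def __init__(self, size=32, latent_dim=16):
        super().__init__()
        self.size = size
        in_dim = 3 * size * size
        self.enc = nn.Linear(in_dim, 256)
        self.fc_mean = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        self.dec1 = nn.Linear(latent_dim, 256)
        self.dec2 = nn.Linear(256, in_dim)

    def forward(self, x):
        h = F.relu(self.enc(x.flatten(1)))
        mean, logvar = self.fc_mean(h), self.fc_logvar(h)
        # reparameterisation trick: z = mean + sigma * eps
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)
        recon = torch.sigmoid(self.dec2(F.relu(self.dec1(z))))
        return recon.view(-1, 3, self.size, self.size), mean, logvar


class VAEThenCNN(nn.Module):
    """Hypothetical VAE followed by a small CNN classifier."""

    def __init__(self, size=32, num_classes=10):
        super().__init__()
        self.vae = TinyVAE(size)
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(8, num_classes),
        )

    def forward(self, x):
        ae_output, mean, logvar = self.vae(x)
        # the same tensor is returned AND fed to the CNN, so the CNN loss
        # also backpropagates into the VAE encoder and decoder
        output = self.cnn(ae_output)
        return output, ae_output, mean, logvar


torch.manual_seed(0)
model = VAEThenCNN()
x = torch.rand(4, 3, 32, 32)
output, ae_output, mean, logvar = model(x)
```

Note that assigning the VAE output to a separate variable does not detach it from the graph: the gradient of the CNN loss still flows back through `ae_output` into the encoder, so `mean` and `logvar` receive updates they never saw in stand-alone training. If the CNN loss is not meant to update the VAE, feeding `ae_output.detach()` to the CNN would cut that path.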
From the very 1st epoch the KLD loss is NaN (BCE loss is fine). Is there a way to stabilise this behaviour? What could make the model work as a stand-alone system but fail when linked with anything else?
The KLD loss tends to be much, much bigger in the 1st epoch, so I thought it might be a data casting problem and added a scaling factor of 1e-10, but even that didn’t help. I also tried setting KLD to a small value in the 1st epoch to prevent the overshoot, but then I just get a NaN in the 2nd epoch…
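A small float32 experiment shows why a scaling factor cannot help if the problem is `logvar.exp()` overflowing (an assumption; the post does not confirm where the NaN originates): once a term is `inf`, multiplying by 1e-10 leaves it `inf`, and `inf - inf`, which can appear in the loss or in the backward pass, is NaN. Clamping `logvar` before the exponential is one common workaround; the bound of 10 used here is an arbitrary illustrative choice.

```python
import torch

mean = torch.zeros(2)
logvar = torch.tensor([0.0, 100.0])  # float32 by default

# exp(100) is about 2.7e43, beyond the float32 maximum (~3.4e38), so it becomes inf
kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
scaled = kld * 1e-10  # scaling does not turn inf back into a finite value
diff = kld - kld      # inf - inf is NaN

# clamp logvar before exponentiating (bound chosen arbitrarily for illustration)
safe_logvar = logvar.clamp(max=10.0)
safe_kld = -0.5 * torch.sum(1 + safe_logvar - mean.pow(2) - safe_logvar.exp())

print(kld, scaled, diff, safe_kld)
```

The trade-off is that `clamp` passes zero gradient for values above the bound, so it hides rather than explains why the encoder produces such large `logvar` values.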
This is my loss function:
```python
import torch
import torch.nn.functional as F

def vae_loss(reconstructed_x, x, mean, logvar, batch_size):
    BCE = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    # normalise the outputs: divide by channels * height * width * batch size
    # (model_size, the image height/width, is assumed to be a global defined elsewhere)
    BCE /= (3 * model_size * model_size * batch_size)
    KLD /= (3 * model_size * model_size * batch_size)
    total = BCE + KLD
    return total, BCE, KLD
```
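As a sanity check on the KLD line, the closed form can be compared against `torch.distributions`, which implements the KL divergence between two normal distributions. A sketch with random `mean`/`logvar` tensors (the shapes and seed are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mean = torch.randn(4, 8)
logvar = torch.randn(4, 8)

# closed form used in vae_loss (Kingma & Welling, Appendix B)
kld_closed = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

# same quantity via torch.distributions: KL(N(mean, sigma^2) || N(0, 1)), summed
q = Normal(mean, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mean), torch.ones_like(mean))
kld_dist = kl_divergence(q, p).sum()

print(kld_closed.item(), kld_dist.item())
```

If the two agree, the expression itself is correct, and a NaN points at the values of `mean` and `logvar` coming out of the encoder rather than at a typo in the formula.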
Then, during training, I do:
```python
output, ae_output, mean, logvar = model(input)
cnn_loss = cnn_criterion(output, labels)  # cross-entropy loss
ae_loss, bce, kld = vae_loss(ae_output, input, mean, logvar, batch_size)
total_loss = ae_loss + cnn_loss
```
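One way to locate where the NaN first appears is to check the loss before `backward()` and to cap the gradient norm of the summed loss. The sketch below is self-contained, so it redefines a variant of `vae_loss` that normalises by `x.numel()` (equal to `3*model_size*model_size*batch_size` for a full batch of 3-channel square images) and uses a tiny hypothetical `ToyModel` with the same four outputs as `model`; `train_step`, `max_grad_norm=1.0` and the SGD settings are illustrative assumptions. `torch.isfinite` and `torch.nn.utils.clip_grad_norm_` are standard PyTorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def vae_loss(reconstructed_x, x, mean, logvar):
    # same terms as in the post, normalised by the number of input elements
    bce = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    n = x.numel()
    return (bce + kld) / n, bce / n, kld / n


class ToyModel(nn.Module):
    # hypothetical stand-in returning (output, ae_output, mean, logvar)
    def __init__(self, size=8, latent_dim=4, num_classes=3):
        super().__init__()
        d = 3 * size * size
        self.enc = nn.Linear(d, 2 * latent_dim)
        self.dec = nn.Linear(latent_dim, d)
        self.head = nn.Linear(d, num_classes)

    def forward(self, x):
        mean, logvar = self.enc(x.flatten(1)).chunk(2, dim=1)
        z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
        recon = torch.sigmoid(self.dec(z))
        return self.head(recon), recon.view_as(x), mean, logvar


def train_step(model, optimizer, inputs, labels, max_grad_norm=1.0):
    output, ae_output, mean, logvar = model(inputs)
    cnn_loss = F.cross_entropy(output, labels)
    ae_loss, bce, kld = vae_loss(ae_output, inputs, mean, logvar)
    total_loss = ae_loss + cnn_loss
    # fail loudly before the optimiser step if any term is already non-finite
    if not torch.isfinite(total_loss):
        raise RuntimeError(
            f"non-finite loss: bce={bce.item()}, kld={kld.item()}, cnn={cnn_loss.item()}")
    optimizer.zero_grad()
    total_loss.backward()
    # cap the gradient norm coming from the summed losses
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return total_loss.item()


torch.manual_seed(0)
model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.rand(4, 3, 8, 8)
labels = torch.randint(0, 3, (4,))
loss = train_step(model, optimizer, inputs, labels)
```

If the forward loss is finite but the parameters still turn into NaN, wrapping the step in `torch.autograd.set_detect_anomaly(True)` makes autograd report which backward operation produced the first NaN.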
Has anyone dealt with something similar? Any ideas how to fix this?