KLD loss goes NaN during VAE training


I’ve trained a stand-alone VAE based on the PyTorch example and a few other bits of code found on GitHub. It works well and my output images look quite good. All of the examples dealt with MNIST, but my model uses ImageNet images, so it’s a bit bigger than the examples. Apart from that, it doesn’t differ much.

When I then use the VAE model in a bigger network, where the output of the VAE is fed to a CNN and their losses are summed (there is a reason for that, I promise), the KLD loss of the VAE part is NaN. The VAE and CNN form one model, but I return the output of the VAE (assigned to a separate variable before it even enters the CNN half) and calculate my VAE loss using the same function I used in the stand-alone VAE example.

From the very 1st epoch the KLD loss is NaN (BCE loss is fine). Is there a way to stabilise this behaviour? What could make the model work as a stand-alone system but fail when linked with anything else?

The KLD loss tends to be much, much bigger in the first epoch, so I thought it might be a data casting problem and added a scaling factor of 1e-10, but even that didn’t help. I considered setting KLD to some small value in the first epoch to prevent the overshoot, but then I just get a NaN in the second epoch…
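A quick check of why a constant scale factor cannot rescue the loss (a sketch with stand-in values, independent of the actual model): once the KLD is inf or NaN, multiplying it by 1e-10 leaves it inf or NaN, and inf minus inf, or zero times inf, produces NaN.

```python
import torch

# Stand-in values for a KLD term that has already blown up (illustrative only).
nan_kld = torch.tensor(float("nan"))
inf_kld = torch.tensor(float("inf"))

scaled_nan = nan_kld * 1e-10   # NaN times anything is still NaN
scaled_inf = inf_kld * 1e-10   # inf times 1e-10 is still inf
diff = inf_kld - inf_kld       # inf - inf is NaN
prod = inf_kld * 0.0           # 0 * inf is NaN

print(scaled_nan, scaled_inf, diff, prod)
```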

This is my loss function:

```python
import torch
import torch.nn.functional as F

# model_size is assumed to be the input image side length (a global defined elsewhere in the original code)
def vae_loss(reconstructed_x, x, mean, logvar, batch_size):
    BCE = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    # normalise the outputs: divide by the number of elements
    # (3 channels * height * width * batch size)
    BCE /= (3*model_size*model_size*batch_size)
    KLD /= (3*model_size*model_size*batch_size)
    total = BCE + KLD
    return total, BCE, KLD
```
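For reference, the KLD line implements the closed form from Appendix B of the paper for a diagonal Gaussian posterior against a standard normal prior. The appendix states the result as the negative divergence, which is why the comment has no minus sign while the code multiplies by -0.5. With logvar = log σ², the σ_j² term is exactly where logvar.exp() enters:

```latex
D_{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\right)
  = -\frac{1}{2} \sum_{j=1}^{J} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right),
\qquad \sigma_j^2 = \exp(\mathrm{logvar}_j)
```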

Then, during training, I do:

```python
output, ae_output, mean, logvar = model(input)
cnn_loss = cnn_criterion(output, labels)  # cross-entropy loss
ae_loss, bce, kld = vae_loss(ae_output, input, mean, logvar, batch_size)
total_loss = ae_loss + cnn_loss
```

Has anyone dealt with something similar? Any ideas how to fix this?

I ran into the same problem when using a VAE for representation learning.
To solve it, you first need to find out what leads to NaN during training. I think logvar.exp() in the following formula overflows during training:

```python
KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
```
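A small check of the overflow claim, using made-up logvar values and assuming float32 tensors (PyTorch's default dtype): exp() overflows float32 once its argument exceeds roughly 88.7, and a single such element makes the whole KLD sum inf.

```python
import torch

# Hypothetical logvar values; float32 is PyTorch's default dtype.
logvar = torch.tensor([0.0, 50.0, 88.0, 89.0])
mean = torch.zeros_like(logvar)

var = logvar.exp()  # exp(89) exceeds the float32 maximum (~3.4e38), so it becomes inf
kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

print(var)  # the last element is inf
print(kld)  # the sum is inf as well
```

An inf loss then typically yields inf or NaN gradients, which is how NaN tends to spread into the weights after an optimizer step.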

So we need to keep logvar within a bounded range somehow.
One option is to initialize the VAE weights in a small range (the range [-0.08, 0.08] is often recommended) so that logvar stays small and exp() cannot overflow numerically.
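A minimal sketch of that initialisation, assuming the VAE is built from nn.Linear, nn.Conv2d and nn.ConvTranspose2d layers. The helper name init_small_uniform and the toy encoder are hypothetical, and zeroing the biases is an extra choice not stated in the reply.

```python
import torch
import torch.nn as nn

def init_small_uniform(module, bound=0.08):
    """Hypothetical helper: draw weights from U(-bound, bound) and zero the biases."""
    if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.uniform_(module.weight, -bound, bound)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Toy stand-in for an encoder; replace with the real VAE.
encoder = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 6 * 6, 4)
)
encoder.apply(init_small_uniform)  # apply() visits every submodule recursively

out = encoder(torch.randn(2, 3, 8, 8))  # 8x8 input -> 6x6 after the 3x3 conv
```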
Also, if you are using a CNN for image classification, you should normalise the images to zero mean and unit variance.
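A sketch of per-channel normalisation with plain tensor ops, assuming RGB images as float tensors in [0, 1] with shape (N, 3, H, W). The mean/std values are the ImageNet statistics commonly used with torchvision models, and normalize_images is a hypothetical helper. One caveat for the setup in the question: F.binary_cross_entropy expects values in [0, 1], so the BCE reconstruction target should remain the un-normalised image.

```python
import torch

# Commonly used ImageNet channel statistics (assumption: RGB images in [0, 1]).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def normalize_images(x):
    """Hypothetical helper: per-channel (x - mean) / std for an (N, 3, H, W) batch."""
    return (x - IMAGENET_MEAN.to(x.device)) / IMAGENET_STD.to(x.device)

images = torch.rand(4, 3, 32, 32)    # stand-in batch in [0, 1]
cnn_input = normalize_images(images)  # feed this to the classifier
bce_target = images                   # keep the [0, 1] images as the BCE target
```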

For details, see “Use of exp in the Reparametrization”. Best wishes.


> I think logvar.exp() in the following formula overflows during training

It could also be this part:

```python
BCE = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')
```
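If the BCE term is the suspect, one commonly used alternative (a swapped-in technique, not what the original code does) is to have the decoder output raw logits and use F.binary_cross_entropy_with_logits, which combines the sigmoid and the log in a numerically stable way. A minimal sketch with made-up tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)  # hypothetical decoder outputs before the sigmoid
x = torch.rand(2, 3, 4, 4)        # targets in [0, 1]

# Original style: sigmoid inside the model, BCE on probabilities.
bce_probs = F.binary_cross_entropy(torch.sigmoid(logits), x, reduction='sum')
# Fused version: works on the logits directly.
bce_logits = F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
print(bce_probs.item(), bce_logits.item())  # nearly identical for moderate logits

# A saturated logit: sigmoid(200.) rounds to exactly 1.0 in float32, so the
# probability version has to take log(0); PyTorch clamps that log at -100.
big, zero = torch.tensor([200.0]), torch.tensor([0.0])
sat_probs = F.binary_cross_entropy(torch.sigmoid(big), zero)
sat_logits = F.binary_cross_entropy_with_logits(big, zero)  # the exact loss is 200
```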

Could you post the exception or error message you get?
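Since a NaN loss does not raise an exception by itself, one way to get a traceback is PyTorch’s autograd anomaly mode (torch.autograd.detect_anomaly), which makes backward() raise an error naming the operation whose gradient became NaN. A toy sketch; anomaly mode slows training, so it is meant for debugging only:

```python
import torch

x = torch.tensor([-1.0], requires_grad=True)
error_message = None
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x).sum()  # sqrt(-1) is NaN, so its gradient is NaN too
    try:
        y.backward()
    except RuntimeError as err:  # anomaly mode turns the NaN gradient into an error
        error_message = str(err)

print(error_message)
```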

I can second this: I clip the logvar values and also add an assertion to detect NaN values.
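A minimal sketch of that approach, assuming the clamp is applied to logvar just before the KLD is computed. The bounds (-10, 10) and the name kld_with_clamp are arbitrary choices for illustration. Note that clamp passes no gradient for values outside the bounds, so logvar entries beyond them receive no gradient from this term.

```python
import torch

def kld_with_clamp(mean, logvar, min_logvar=-10.0, max_logvar=10.0):
    """Hypothetical helper: KLD term with logvar clipped to a fixed range."""
    logvar = logvar.clamp(min=min_logvar, max=max_logvar)  # keeps exp() far from overflow
    kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    # Fail fast instead of silently training on NaN/inf.
    assert torch.isfinite(kld).all(), f"KLD is not finite: {kld.item()}"
    return kld

mean = torch.zeros(2, 3)
kld_ok = kld_with_clamp(mean, torch.zeros(2, 3))            # zero for a standard normal
kld_big = kld_with_clamp(mean, torch.full((2, 3), 1000.0))  # finite thanks to the clamp
```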
