KLD loss goes NaN during VAE training

Hello!

I’ve trained a stand-alone VAE based on the PyTorch example and a few other bits of code found on GitHub. It works well and my output images look quite good. All of the examples dealt with MNIST, but my model uses ImageNet images, so it’s a bit bigger than the examples. Apart from that, it doesn’t differ too much.

When I then want to use the VAE model in a bigger network where the output of the VAE is fed to a CNN and their losses are summed (there is a reason for that, I promise) the KLD loss of the VAE part is NaN. The VAE and CNN form one model but I return the output of the VAE (assigned to a separate variable before it even enters the CNN half) and calculate my VAE loss using the same function I used in the stand-alone VAE example.

From the very 1st epoch the KLD loss is NaN (BCE loss is fine). Is there a way to stabilise this behaviour? What could make the model work as a stand-alone system but fail when linked with anything else?

The KLD loss tends to be much much bigger in the 1st epoch so I thought it might be a data casting problem and added a scaling factor of 1e-10 but even that didn’t help. I considered setting KLD to some small value in the 1st epoch to prevent the overshoot but then I just get a NaN in the 2nd epoch…

This is my loss function:

import torch
import torch.nn.functional as F

def vae_loss(reconstructed_x, x, mean, logvar, batch_size):
    BCE = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

    # normalise the outputs: divide by channels * spatial dimensions * batch size
    # (model_size is defined elsewhere; presumably the image side length)
    BCE /= (3*model_size*model_size*batch_size)
    KLD /= (3*model_size*model_size*batch_size)
    total = BCE+KLD
    return total, BCE, KLD

Then, during training, I say

  output, ae_output, mean, logvar = model(input)
  cnn_loss = cnn_criterion(output, labels) # cross entropy loss
  ae_loss, bce, kld = vae_loss(ae_output, input, mean, logvar, batch_size)
  total_loss = ae_loss + cnn_loss

Has anyone dealt with something similar? Any ideas how to fix this?

I ran into the same problem when using a VAE for representation learning.
To solve it, you need to know what leads to NaN during training. I think the logvar.exp() in the following formula overflows at run time:

KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

So we need to limit logvar to a specific range by some means.
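To make the overflow concrete, here is a small sketch (the values 88 and 89 are just picked for illustration). In float32, exp() stops being finite a little below 89, rescaling the result afterwards does not bring it back, and inf minus inf is NaN:

```python
import torch

# float32 is the default dtype; its largest finite value is about 3.4e38
logvar = torch.tensor([88.0, 89.0])
var = logvar.exp()

print(var)              # first entry is finite, second has overflowed to inf
print(var * 1e-10)      # multiplying by a tiny factor leaves inf as inf
print(var[1] - var[1])  # inf - inf is nan
```

This would also explain why a scaling factor such as 1e-10 applied to the KLD term does not help: by the time the factor is applied, the value is already non-finite.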
One option is to initialize the VAE weights in a small range (the range [-0.08, 0.08] is recommended) so that logvar is small and exp cannot overflow numerically.
Also, if you are using a CNN for image classification, you should normalize the images to zero mean and unit variance.
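A minimal sketch of that initialization, assuming the encoder is built from standard nn.Linear / nn.Conv2d layers. The helper name init_small_uniform, the toy encoder_head module and its layer sizes are made up for illustration and are not from the original model:

```python
import torch
from torch import nn

def init_small_uniform(module, bound=0.08):
    """Draw weights uniformly from [-bound, bound] and zero the biases."""
    if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.uniform_(module.weight, -bound, bound)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Hypothetical stand-in for the part of the VAE that produces mean / logvar
encoder_head = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
encoder_head.apply(init_small_uniform)  # apply() visits every submodule
```

Note that this only controls where logvar starts. It does not stop it from growing later in training, so it may be worth combining with an explicit bound on logvar.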

For details, you can visit this thread: "Use of exp in the Reparametrization". Best wishes.
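Another way to keep logvar in range is to clamp it before it reaches exp(). The sketch below is only an illustration: the function name kld_clamped and the bounds of -10 and 10 are assumptions, not values from the original post.

```python
import torch

def kld_clamped(mean, logvar, min_logvar=-10.0, max_logvar=10.0):
    # Bound logvar so that exp() stays finite in float32.
    # The bounds are illustrative and may need tuning for a given model.
    logvar = torch.clamp(logvar, min=min_logvar, max=max_logvar)
    return -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

mean = torch.zeros(4, 8)
logvar = torch.full((4, 8), 100.0)  # deliberately far too large

unclamped = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
print(torch.isfinite(unclamped).item())                  # False
print(torch.isfinite(kld_clamped(mean, logvar)).item())  # True
```

One thing to keep in mind: torch.clamp passes zero gradient for elements outside the bounds, so the KLD term stops pushing those logvar values back towards the valid range.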


To solve it, you need to know what leads to NaN during training. I think the logvar.exp() in the following formula overflows at run time

It could also be this part:


BCE = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')

Could you post the exception or traceback you get?
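To rule the BCE term in or out, it may help to check that both of its arguments are finite and lie in [0, 1], which is what binary cross entropy expects. A rough sketch with dummy tensors follows; in_unit_interval is a made-up helper name and the tensors only stand in for the real batch:

```python
import torch
import torch.nn.functional as F

def in_unit_interval(t):
    """True if every element of t is finite and lies in [0, 1]."""
    return bool(torch.isfinite(t).all() and t.min() >= 0 and t.max() <= 1)

# Dummy data shaped like a small image batch (N, C, H, W)
recon = torch.sigmoid(torch.linspace(-3, 3, 96).reshape(2, 3, 4, 4))  # sigmoid output
raw = torch.linspace(0, 1, 96).reshape(2, 3, 4, 4)                    # pixels in [0, 1]
standardised = (raw - 0.5) / 0.25                                     # shifted and rescaled pixels

print(in_unit_interval(recon), in_unit_interval(raw), in_unit_interval(standardised))

bce = F.binary_cross_entropy(recon, raw, reduction='sum')
```

If the images are standardised to zero mean and unit variance for the CNN, as suggested earlier in the thread, the same tensor is no longer a valid BCE target, so the VAE loss may need the unnormalised copy.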

I can second this. I clip the logvar values and also add an assertion to detect NaN values.
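For the assertion, something along these lines might do. It is a sketch only: assert_finite is a hypothetical helper, and the tensors are dummies standing in for the model outputs.

```python
import torch

def assert_finite(**tensors):
    """Raise AssertionError naming the first tensor that contains NaN or inf."""
    for name, t in tensors.items():
        assert torch.isfinite(t).all(), f"non-finite values in {name}"

mean = torch.zeros(2, 3)
logvar = torch.tensor([[0.0, float("nan"), 0.0], [0.0, 0.0, 0.0]])

assert_finite(mean=mean)  # passes silently

try:
    assert_finite(mean=mean, logvar=logvar)
except AssertionError as err:
    message = str(err)
    print(message)
```

Calling it on mean, logvar and the reconstruction right after the forward pass shows which tensor goes bad first. PyTorch also ships torch.autograd.set_detect_anomaly(True), which makes the backward pass raise an error when a backward function produces NaN, at the cost of slower training.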