ELBO loss in PyTorch

I’ve read that when data is binary, the reconstruction loss is modeled by a multivariate factorized Bernoulli distribution using torch.nn.functional.binary_cross_entropy, so the ELBO loss can be implemented like this:

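import torch
import torch.nn.functional as F
def loss_function(recon_x, x, mu, logvar):
   # Bernoulli reconstruction term (patch_size = side length of the square patch)
   BCE = F.binary_cross_entropy(recon_x, x.view(-1, patch_size*patch_size), reduction='sum')
   # Closed-form KL(q(z|x) || N(0, I)), summed over batch and latent dims
   KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
   return BCE + KLD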
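The correspondence described in the question can be checked directly: the summed BCE equals the negative log-likelihood of independent Bernoulli pixels whose probabilities are the decoder outputs. A small sketch with random tensors; the shapes (4 flattened 8x8 patches) are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli

torch.manual_seed(0)
# Hypothetical decoder probabilities and binary targets for 4 flattened 8x8 patches
recon_x = torch.rand(4, 64).clamp(1e-4, 1 - 1e-4)
x = torch.randint(0, 2, (4, 64)).float()

# Reconstruction term as computed in the snippet above
bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
# Negative log-likelihood of a factorized Bernoulli with per-pixel probs recon_x
nll = -Bernoulli(probs=recon_x).log_prob(x).sum()
```

The two values agree up to floating-point error, which is why BCE is the natural reconstruction loss only when the targets really are 0/1.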

My data is not binary. How can I implement the ELBO loss correctly so that it converges to zero?


You can express the ELBO as E_Q[log P(X|Z)] - KL(Q(Z|X) || P(Z)), and torch.distributions has the relevant densities (log_prob) and KL formulas (kl_divergence), but you must select a distribution type first, e.g. MultivariateNormal, or a univariate distribution wrapped in Independent.
P.S. I believe this doesn’t converge to zero for non-binary distributions (a continuous density can exceed 1, so -ELBO can even become negative). Also, the loss should be -ELBO.
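A minimal sketch of the torch.distributions route, assuming (as a hypothetical choice) a diagonal Gaussian decoder with a fixed standard deviation `recon_std` and a standard normal prior. It uses a single sample of z, so `recon_mu` is the decoder output for that sample and the expectation is a one-sample Monte Carlo estimate; the function returns -ELBO summed over the batch.

```python
import torch
from torch.distributions import Independent, Normal, kl_divergence


def negative_elbo(x, recon_mu, mu, logvar, recon_std=1.0):
    """Single-sample estimate of -ELBO, summed over the batch.

    Hypothetical modelling choices: diagonal Gaussian decoder with fixed
    std `recon_std`, standard normal prior P(Z).
    x, recon_mu: (batch, D) data and decoder means for z ~ Q(Z|X)
    mu, logvar:  (batch, latent) encoder outputs
    """
    # P(X|Z): independent Gaussians over the D pixels, treated as one event
    p_x_given_z = Independent(Normal(recon_mu, torch.full_like(recon_mu, recon_std)), 1)
    # Q(Z|X): diagonal Gaussian from the encoder
    q_z = Independent(Normal(mu, torch.exp(0.5 * logvar)), 1)
    # P(Z): standard normal prior with the same shape as q_z
    p_z = Independent(Normal(torch.zeros_like(mu), torch.ones_like(mu)), 1)
    log_lik = p_x_given_z.log_prob(x)  # shape (batch,)
    kl = kl_divergence(q_z, p_z)       # shape (batch,)
    return (kl - log_lik).sum()
```

For images, something like `negative_elbo(x.view(x.size(0), -1), recon_x, mu, logvar)` would apply it to flattened inputs, assuming `recon_x` is already flattened. With `recon_std=1.0` the log-density contains the constant 0.5·D·log(2π), which alone keeps the value away from zero; with a much smaller `recon_std` and accurate reconstructions it can become negative.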


How can I know which distribution I should use? I’m using a dataset of gray-scale images, and my example code doesn’t work properly with it. I’ve read that for continuous data a diagonal Gaussian distribution may be appropriate, but I can’t find an implementation. Also, how can I tell whether my data are continuous?
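Gray-scale images scaled to [0, 1] usually take many intermediate values and are commonly treated as continuous; a simple test is whether every value is exactly 0 or 1. The sketch below pairs that check with a diagonal Gaussian negative log-likelihood. `recon_mu` and `recon_logvar` are hypothetical decoder outputs (the original code has no variance output), and all function names are illustrative.

```python
import math

import torch


def is_binary(x: torch.Tensor) -> bool:
    # True only if every entry is exactly 0 or 1 (e.g. binarized images)
    return bool(((x == 0) | (x == 1)).all())


def diag_gaussian_nll(x, recon_mu, recon_logvar):
    """-log N(x; recon_mu, diag(exp(recon_logvar))), summed over all elements.

    recon_mu and recon_logvar are assumed to have the same shape as x;
    recon_logvar may be a learned decoder output or a fixed tensor.
    """
    return 0.5 * torch.sum(
        recon_logvar
        + (x - recon_mu) ** 2 / recon_logvar.exp()
        + math.log(2 * math.pi)
    )
```

With `recon_logvar` fixed to zeros this reduces to `0.5 * F.mse_loss(recon_mu, x, reduction='sum')` plus a constant, which is one way to see why a summed MSE is a reasonable reconstruction term for continuous pixels.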


In the VAE context, the KL term is computed for the latent z and acts as a regularizer. There, the simplest distribution can be used as the prior: a standard independent Gaussian.

Your snippet already computes the KL against a standard Gaussian, so you only have to switch the reconstruction term to a loss based on pixel value distances (it seems a simple MSE can be used for this).
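For reference, this is the closed form the KLD line computes, written for one sample with a J-dimensional latent, where `logvar` holds log σ_j² and `mu` holds μ_j (the snippet's `torch.sum` additionally sums over the batch):

```latex
\mathrm{KL}\left(\mathcal{N}\left(\mu, \operatorname{diag}(\sigma^2)\right) \,\middle\|\, \mathcal{N}(0, I)\right)
  = -\frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```

Each term is zero exactly when μ_j = 0 and σ_j = 1, which is why this KL pulls the encoder toward the prior.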

So would you advise something like this?

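import torch
import torch.nn.functional as F
def loss_function(recon_x, x, mu, logvar):
   # Reconstruction term: squared error averaged over all elements
   MSE = F.mse_loss(recon_x, x, reduction='mean')
   # Closed-form KL(q(z|x) || N(0, I)), summed over batch and latent dims
   KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
   return MSE + KLD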

Yes, something like this. You can compare it with PyTorch-VAE/vanilla_vae.py at master · AntixK/PyTorch-VAE · GitHub. I think you have to reduce the KL term too (take the mean).
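One way to read "reduce the KL too" is to reduce both terms the same way: sum over each sample's pixels or latent dimensions, then average over the batch. A minimal sketch, assuming `recon_x` and `x` share a shape with the batch on dim 0; it is not claimed to match vanilla_vae.py line for line, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def loss_function_batch_mean(recon_x, x, mu, logvar):
    """Both terms summed per sample, then averaged over the batch.

    Assumes recon_x and x have the same shape, with the batch on dim 0.
    """
    batch_size = x.size(0)
    # Squared error summed over pixels, averaged over the batch
    recon = F.mse_loss(recon_x, x, reduction='sum') / batch_size
    # KL summed over latent dims (dim=1), averaged over the batch
    kld = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    return recon + kld
```

Because both terms are divided by the same batch size, their relative weight is the same as in a fully summed loss; applying `reduction='mean'` to the MSE alone is what changes that balance.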

I’ve tried this solution, but the reconstructed images are completely gray, so the network doesn’t seem to learn. What can I do?

I can’t say for sure; maybe your learning rate is too high. You can try disabling the KL component, or reproducing the results on an example dataset, to narrow down the problem.
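To try disabling the KL component without rewriting the loss each time, one option is a weight on the KL term. In this sketch `beta` is a hypothetical debugging knob, not part of the original code: with `beta=0.0` the model trains as a plain autoencoder, and if reconstructions are still uniformly gray, the problem is probably not the balance between the two terms.

```python
import torch
import torch.nn.functional as F


def loss_with_kl_weight(recon_x, x, mu, logvar, beta=1.0):
    """Summed reconstruction + beta * KL; beta=0.0 switches the KL term off.

    `beta` is a hypothetical knob for debugging, not from the original snippet.
    """
    # Squared error summed over all elements
    recon = F.mse_loss(recon_x, x, reduction='sum')
    # Closed-form KL(q(z|x) || N(0, I)), summed over batch and latent dims
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kld
```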

I hope you’ve already solved it.
I’m currently learning about VAEs and have some comments on your question.
I think the problem is that you use

MSE = F.mse_loss(recon_x, x, reduction='mean')

You may try

MSE = F.mse_loss(recon_x, x, reduction='sum')

As you did for BCE.

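If you average the MSE (reduction='mean') but sum the KLD, the KLD value will usually be far larger than the MSE value, because the mean divides the squared error by the total number of elements (batch size × pixels) while the KLD is not divided at all.
So the model will mostly try to reduce the much larger KLD loss.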
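To see the size of the imbalance, here is a small numerical sketch with a hypothetical batch of 64 random 28x28 images and a hypothetical 20-dimensional latent whose posterior is only mildly off the prior (mu = 0.5, logvar = 0). The averaged MSE stays below 1 for pixels in [0, 1], while the summed KLD is already 160.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical batch: 64 gray-scale 28x28 images and reconstructions in [0, 1]
x = torch.rand(64, 1, 28, 28)
recon_x = torch.rand(64, 1, 28, 28)

mse_mean = F.mse_loss(recon_x, x, reduction='mean')
mse_sum = F.mse_loss(recon_x, x, reduction='sum')
# 'mean' divides the summed error by every element: 64 * 1 * 28 * 28 = 50176
ratio = (mse_sum / mse_mean).item()  # equals x.numel() up to float rounding

# Hypothetical 20-dim latent, posterior only mildly off the prior: mu = 0.5, std = 1
mu = torch.full((64, 20), 0.5)
logvar = torch.zeros(64, 20)
# Each element contributes 0.5 * 0.5**2 = 0.125, so the sum is 0.125 * 1280 = 160
kld_sum = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
```

With `reduction='sum'` on the MSE, neither term is divided, which matches how the KLD is computed.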
If you print the mean and standard deviation output by the encoder after feeding a sample to the VAE,
you will likely find that mu=0 and std=1 for every element. The encoder output then carries no information about the input, so you get the same image no matter which image you feed to the VAE.
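To run the check described above, a small helper, assuming `mu` and `logvar` are the encoder outputs for a batch with shape (batch, latent). It also reports the KL per latent dimension; dimensions with KL near zero match the prior and carry no information. The function names and the 0.01 threshold are arbitrary, hypothetical choices.

```python
import torch


def kl_per_latent_dim(mu, logvar):
    """KL(Q(z_j|x) || N(0, 1)) for each latent dimension j, averaged over the batch."""
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())  # (batch, latent)
    return kl.mean(dim=0)


def report_posterior(mu, logvar, threshold=0.01):
    # threshold is a hypothetical cut-off for calling a dimension "inactive"
    std = torch.exp(0.5 * logvar)
    kl = kl_per_latent_dim(mu, logvar)
    return {
        "mean |mu|": mu.abs().mean().item(),
        "mean std": std.mean().item(),
        "inactive dims": int((kl < threshold).sum()),
    }
```

You could call `report_posterior(mu, logvar)` on your encoder's outputs inside `torch.no_grad()`; if every dimension is inactive, the decoder is ignoring z.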