ELBO loss in PyTorch

I’ve read that when the data is binary, the reconstruction likelihood is modeled as a factorized multivariate Bernoulli distribution via torch.nn.functional.binary_cross_entropy, so the ELBO loss can be implemented like this:

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    # BCE: Bernoulli reconstruction term; KLD: closed-form KL to N(0, I)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, patch_size*patch_size), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

My data is not binary. How can I implement the ELBO loss correctly so that it converges to zero?

You can express the ELBO as log p(X) − KL(Q‖P), and torch.distributions has the relevant density and KL formulas, but you must select a distribution type first, e.g. MultivariateNormal, or a univariate distribution wrapped in the Independent class.

PS: I believe this doesn’t converge to zero for non-binary distributions. Also, the loss should be −ELBO, since training maximizes the ELBO by minimizing its negative.
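
For example, here is a minimal sketch of a −ELBO loss for continuous data built from torch.distributions; the function name neg_elbo, the flattened (batch, features) shapes, and the fixed unit scale of the likelihood are assumptions for illustration:

import torch
from torch.distributions import Independent, Normal, kl_divergence

def neg_elbo(recon_x, x, mu, logvar):
    # q(z|x): diagonal Gaussian posterior built from the encoder outputs
    q = Independent(Normal(mu, torch.exp(0.5 * logvar)), 1)
    # p(z): standard Gaussian prior with the same shape
    p = Independent(Normal(torch.zeros_like(mu), torch.ones_like(mu)), 1)
    # p(x|z): factorized Gaussian likelihood; the unit scale is an arbitrary choice
    px = Independent(Normal(recon_x, torch.ones_like(recon_x)), 1)
    # loss = -(log-likelihood - KL) = -ELBO
    return kl_divergence(q, p).sum() - px.log_prob(x).sum()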

How can I know what distribution I should use? I’m using a dataset that contains gray-scale images, and my example code doesn’t work properly with it. I’ve read that with continuous data a diagonal Gaussian distribution may be appropriate, but I can’t find an implementation. Also, how can I tell whether my data are continuous?

In the VAE context, the KL term is computed on the latent z and acts as a regularizer. In that setting, the simplest distribution can be used as the prior: a standard independent Gaussian.

Your snippet already computes the standard-Gaussian KL, so you only have to switch the reconstruction term to some loss based on pixel distances (it seems that simple MSE can be used for this).
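
For reference, the KLD line in your snippet is the closed-form KL between the diagonal Gaussian posterior and the standard Gaussian prior, with logvar = log σ²:

    \mathrm{KL}\left(\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)\right) = -\frac{1}{2} \sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)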

So do you advise me to use something like this?

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    MSE = F.mse_loss(recon_x, x, reduction='mean')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD

Yes, something like this. You can compare with PyTorch-VAE/vanilla_vae.py at master · AntixK/PyTorch-VAE · GitHub. I think you have to reduce the KL term too (take the mean), so that it is on the same scale as the MSE.
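
For instance, a minimal sketch with both terms reduced the same way (summed over dimensions, averaged over the batch); the kld_weight knob and the flattened (batch, features) shapes are assumptions loosely following the linked repo:

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar, kld_weight=1.0):
    # sum the squared error over pixels, then average over the batch
    MSE = F.mse_loss(recon_x, x, reduction='none').sum(dim=1).mean()
    # same reduction for the KL term, so both terms live on the same scale
    KLD = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)).mean()
    return MSE + kld_weight * KLD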

I’ve tried this solution, but the reconstructed images are totally gray, so the network doesn’t learn. What can I do?

Can’t say for sure; maybe your learning rate is too big. You can try disabling the KL component, or reproducing results on standard example datasets, to narrow down the problem.
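
For example, a hypothetical beta factor on the loss makes it easy to switch the KL term off while debugging (beta = 0 disables it):

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar, beta=0.0):
    # beta = 0.0 removes the KL term so you can check pure reconstruction first
    MSE = F.mse_loss(recon_x, x, reduction='mean')
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + beta * KLD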