Implementation of simple VAE loss

Hello, I’m trying to implement a C++ version of this VAE example, but I’m having trouble implementing the KL divergence term. My implementation is as follows:

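// do stuff
// Reconstruction term: torch::mse_loss uses mean reduction by default
torch::Tensor mse = torch::mse_loss(recon_batch, batch.data.reshape({recon_batch.size(0), 784}));
// KL divergence term, summed over every batch and latent element
torch::Tensor kld = -0.5 * torch::sum(1 + logvar - mu.pow(2) - logvar.exp());
auto loss = mse + kld;
loss.backward();
// do stuff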

However, the loss stops decreasing after very few epochs and gives bad reconstruction results compared to the Python example. When I optimize based only on the reconstruction loss, it decreases to a lower point. Any ideas on what I’m missing? Thanks! :slight_smile:
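
One possible explanation, offered as a guess rather than a confirmed diagnosis, is a scale mismatch between the two terms. In the snippet above the reconstruction term is averaged over all batch_size x 784 elements (the default reduction of torch::mse_loss), while the KL term is summed over every batch and latent element. If the Python example sums its reconstruction term (an assumption, since the linked example is not shown here), the KL term carries far more relative weight in the C++ version and could push the model toward ignoring the input. The sketch below uses only the C++ standard library, not libtorch, to show how the two reductions differ; the function names are hypothetical and exist only for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Sum of squared errors over all elements ("sum" reduction).
double squared_error_sum(const std::vector<double>& recon,
                         const std::vector<double>& target) {
    assert(recon.size() == target.size());
    double total = 0.0;
    for (std::size_t i = 0; i < recon.size(); ++i) {
        const double diff = recon[i] - target[i];
        total += diff * diff;
    }
    return total;
}

// Mean of squared errors over all elements ("mean" reduction).
// This is the sum divided by the element count (batch_size * 784 in the post).
double squared_error_mean(const std::vector<double>& recon,
                          const std::vector<double>& target) {
    assert(!recon.empty());
    return squared_error_sum(recon, target) / static_cast<double>(recon.size());
}

// Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over all elements:
// -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), the same expression as in the post.
double kld_sum(const std::vector<double>& mu,
               const std::vector<double>& logvar) {
    assert(mu.size() == logvar.size());
    double total = 0.0;
    for (std::size_t i = 0; i < mu.size(); ++i) {
        total += 1.0 + logvar[i] - mu[i] * mu[i] - std::exp(logvar[i]);
    }
    return -0.5 * total;
}
```

With a batch of 128 images, squared_error_mean is smaller than squared_error_sum by a factor of 128 x 784, whereas kld_sum is not scaled down at all. In libtorch, putting both terms on the same scale should amount to passing at::Reduction::Sum as the third argument of torch::mse_loss, or alternatively dividing the KL term by the same element count; neither variant has been tested against the code in this post.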