Correct Implementation of Beta-VAE Reconstruction Loss with ViT Encoder-Decoder Architecture

I am implementing a Vision Transformer (ViT) encoder-decoder architecture trained within a Beta-VAE framework on noisy latent codes. My encoder returns the learned mean (mu) and log variance (logvar) of the latent space. Despite various attempts, I have not yet achieved the expected results for my decoded samples. I suspect the issue may lie in the definition of my objective, particularly the reconstruction loss. I have come across various ways of implementing the reconstruction loss in different works and am unsure whether my current approach is correct.
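
For context, the overall objective I am minimizing is the standard Beta-VAE loss, i.e. a reconstruction term plus a β-weighted KL term:

$$
\mathcal{L}(\theta, \phi; x) = -\,\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] + \beta \, D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
$$

so the reconstruction_loss method below is meant to estimate the first term, and kl_divergence the second.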

Current Implementation:
I am using F.mse_loss with reduction='none' to get the pixel-wise error, then summing the errors over the channel, height, and width dimensions to obtain a per-sample loss. Summing over these dimensions (instead of averaging) noticeably improved the variance of my decoded samples. Additionally, I have been experimenting with a Negative Log-Likelihood (NLL) loss with a learnable scale, since I am working with noisy (normally distributed) latent codes of shape (B, 4, 32, 32).
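
To make concrete what I mean by "various ways of implementing the reconstruction loss", here is a minimal, self-contained sketch (not my actual code; the function names are just for illustration) of the two reduction conventions I keep seeing. The per-element mean is smaller than the per-sample sum by a factor of C * H * W, which effectively rescales the KL term for a given beta:

    import torch
    import torch.nn.functional as F

    def recon_loss_sum(recon_x, x):
        # Sum squared errors over C, H, W per sample, then average over the batch.
        # Corresponds (up to a constant) to the NLL of a unit-variance Gaussian decoder.
        return F.mse_loss(recon_x, x, reduction='none').sum(dim=(1, 2, 3)).mean()

    def recon_loss_mean(recon_x, x):
        # Average squared error over all B * C * H * W elements.
        # Smaller by a factor of C * H * W, so the same beta weighs the KL term more heavily.
        return F.mse_loss(recon_x, x, reduction='mean')

    x = torch.randn(8, 4, 32, 32)              # (B, C, H, W), e.g. my noisy latent codes
    recon_x = x + 0.1 * torch.randn_like(x)    # stand-in for a decoder output
    print(recon_loss_sum(recon_x, x) / recon_loss_mean(recon_x, x))  # = C * H * W = 4096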

Questions:

  1. Is the reconstruction loss correctly implemented for a ViT decoder?
  2. Are there any potential issues or improvements that I could consider?

Code Sample:

    def reparametrize(self, mu, logvar):
        """Reparametrization trick to sample latent variable z ~ N(mu, sigma^2)."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)  # Sample from standard normal distribution
            return eps * std + mu
        else:
            return mu

    def reconstruction_loss(self, x, recon_x):
        """Reconstruction loss."""
        if self.recon_loss_type == 'L2':
            # Standard MSE Loss (L2, pixel-wise no uncertainty)
            mse_loss = F.mse_loss(recon_x, x, reduction='none')  # Pixel-wise error
            recon_loss = mse_loss.sum(dim=(1, 2, 3))  # Sum over C, H, W -> one loss value per sample for input of shape (B, C, H, W)
        elif self.recon_loss_type == 'NLL':
            # Negative Log-Likelihood with learnable scale (models aleatoric uncertainty)
            var = torch.exp(self.scale)  # self.scale is a learnable log-variance (e.g. an nn.Parameter); exp keeps the variance positive
            dist_normal = torch.distributions.Normal(recon_x, torch.sqrt(var))  # Normal(mean=recon_x, std=sqrt(var))
            log_pxz = dist_normal.log_prob(x)
            recon_loss = -log_pxz.sum(dim=(1, 2, 3))  # Sum over C, H, W -> one loss value per sample for input of shape (B, C, H, W)
        else:
            raise ValueError('Undefined reconstruction loss type.')

        return torch.mean(recon_loss, dim=0)  # Mean over the batch dimension

    def kl_divergence(self, mu, logvar):
        """KL divergence between approximate posterior and prior."""
        # Assumes flattened posterior parameters of shape (B, latent_dim);
        # for spatial latents such as (B, 4, 32, 32), the sum should run over dim=(1, 2, 3).
        kld_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)
        return kld_loss
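
For completeness, here is a simplified, self-contained sketch of how I combine the two terms into the total training loss (the function name beta_vae_objective, the beta value, and the flattened (B, latent_dim) shape of mu/logvar are illustrative assumptions, not my exact code; in my model, mu and logvar come from the ViT encoder and recon_x from the ViT decoder):

    import torch

    def beta_vae_objective(x, recon_x, mu, logvar, beta=4.0):
        # Reconstruction term: summed squared error per sample, averaged over the batch (the 'L2' branch above).
        recon = ((recon_x - x) ** 2).sum(dim=(1, 2, 3)).mean()
        # KL term between N(mu, sigma^2) and the standard normal prior, summed over latent dims.
        kld = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1))
        return recon + beta * kld

    # Toy shapes matching the question:
    x = torch.randn(8, 4, 32, 32)             # target (noisy latent codes)
    recon_x = x + 0.1 * torch.randn_like(x)   # stand-in for the decoder output
    mu, logvar = torch.randn(8, 256), torch.zeros(8, 256)  # assumed flattened posterior parameters
    print(beta_vae_objective(x, recon_x, mu, logvar))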