I am implementing a Vision Transformer (ViT) encoder-decoder architecture trained within a beta-VAE framework on noisy latent codes. My encoder returns the learned mean (`mu`) and log variance (`logvar`) of the latent space. Despite various attempts, I haven't yet achieved the expected results for my decoded samples. I suspect the issue might lie in the definition of my objective, particularly the reconstruction loss. I have come across various implementations of the reconstruction loss in different works and am unsure whether my current approach is correct.
Current Implementation:
I am using `F.mse_loss` with `reduction='none'` to get the pixel-wise error, then summing over the channel, height, and width dimensions to get the loss per sample. Summing over these dimensions (rather than averaging) noticeably improved the variance of my decoded samples. Additionally, I have been experimenting with a Negative Log-Likelihood (NLL) loss, since I am working with noisy (normally distributed) latent codes of shape `(B, 4, 32, 32)`.
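For intuition on how these two options relate: the sum-reduced MSE is the fixed-variance Gaussian NLL up to a scale and an additive constant, which is easy to verify in plain Python (a toy sketch with made-up values, not the training code):

```python
import math

def gaussian_nll(x, mu, var):
    """Per-element negative log-likelihood of x under N(mu, var)."""
    return 0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)

x = [0.2, -1.3, 0.7]      # toy "pixels"
recon = [0.0, -1.0, 1.0]  # toy reconstruction

nll = sum(gaussian_nll(xi, mi, 1.0) for xi, mi in zip(x, recon))
sse = sum((xi - mi) ** 2 for xi, mi in zip(x, recon))  # summed squared error

# With unit variance: NLL = 0.5 * SSE + constant (0.5 * log(2*pi) per element)
const = 0.5 * len(x) * math.log(2 * math.pi)
assert abs(nll - (0.5 * sse + const)) < 1e-12
```

So with a fixed unit variance the two losses produce the same gradients up to a factor of 2; the NLL only changes training behavior once the variance is learned.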
Questions:
- Is the reconstruction loss correctly implemented for a ViT decoder?
- Are there any potential issues or improvements that I could consider?
Code Sample:
```python
def reparametrize(self, mu, logvar):
    """Reparametrization trick to sample latent variable z ~ N(mu, sigma^2)."""
    if self.training:
        std = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2)
        eps = torch.randn_like(std)    # Sample from standard normal distribution
        return eps * std + mu
    else:
        return mu
```
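As a quick sanity check (a standalone sketch with made-up values, not part of the model), samples produced by `eps * std + mu` should match the statistics encoded by `mu` and `logvar`:

```python
import torch

torch.manual_seed(0)
n = 100_000
mu = torch.full((n,), 2.0)
logvar = torch.full((n,), 2 * torch.log(torch.tensor(0.5)).item())  # sigma = 0.5

std = torch.exp(0.5 * logvar)  # recovers sigma = 0.5
eps = torch.randn(n)           # standard normal noise
z = eps * std + mu

# Empirical moments of z should approach mu = 2.0 and sigma = 0.5
assert abs(z.mean().item() - 2.0) < 0.01
assert abs(z.std().item() - 0.5) < 0.01
```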
```python
def reconstruction_loss(self, x, recon_x):
    """Reconstruction loss."""
    if self.recon_loss_type == 'L2':
        # Standard MSE loss (L2, pixel-wise, no uncertainty)
        mse_loss = F.mse_loss(recon_x, x, reduction='none')  # Pixel-wise error
        recon_loss = mse_loss.sum(dim=(1, 2, 3))  # Sum over (C, H, W) -> one loss per sample
    elif self.recon_loss_type == 'NLL':
        # Negative log-likelihood with a learnable scale (models aleatoric uncertainty)
        var = torch.exp(self.scale)  # Exponentiation guarantees a positive variance
        dist_normal = torch.distributions.Normal(recon_x, torch.sqrt(var))  # Learned std
        log_pxz = dist_normal.log_prob(x)
        recon_loss = -log_pxz.sum(dim=(1, 2, 3))  # Sum over (C, H, W) -> one loss per sample
    else:
        raise ValueError('Undefined reconstruction loss type.')
    return torch.mean(recon_loss, dim=0)  # Mean over the batch dimension
```
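The two branches are more closely related than they may look: with a *fixed* variance, the NLL branch is just a rescaled, shifted version of the L2 branch. A minimal check with toy tensors (shapes matching the `(B, 4, 32, 32)` latents; the fixed variance here is only for illustration):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 32, 32)        # toy target batch
recon_x = torch.randn(2, 4, 32, 32)  # toy reconstruction

# L2 branch: summed squared error per sample
sse = ((recon_x - x) ** 2).sum(dim=(1, 2, 3))

# NLL branch with a fixed variance
var = torch.tensor(0.5)
dist = torch.distributions.Normal(recon_x, torch.sqrt(var))
nll = -dist.log_prob(x).sum(dim=(1, 2, 3))

# NLL = SSE / (2 * var) + constant, so the two differ only in scale and offset
n_elems = 4 * 32 * 32
const = 0.5 * n_elems * torch.log(2 * torch.pi * var)
assert torch.allclose(nll, sse / (2 * var) + const, rtol=1e-4)
```

This is why the choice matters mainly when `self.scale` is trainable: the effective weight on the squared error, `1 / (2 * var)`, then adapts during training and rebalances reconstruction against the KL term.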
```python
def kl_divergence(self, mu, logvar):
    """KL divergence between approximate posterior N(mu, sigma^2) and prior N(0, I)."""
    # Sum over all latent dimensions; for (B, 4, 32, 32) latents, summing dim=1 alone
    # would leave a (B, 32, 32) tensor instead of one KL value per sample.
    kld = -0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=(1, 2, 3))
    return torch.mean(kld, dim=0)  # Mean over the batch dimension
```
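For completeness, the pieces combine into the usual beta-VAE objective, reconstruction plus `beta` times the KL term. A minimal free-standing sketch (the `beta_vae_loss` function and `beta` hyperparameter are assumptions for illustration, not taken from your class):

```python
import torch

def beta_vae_loss(x, recon_x, mu, logvar, beta=4.0):
    """Sketch of the total beta-VAE objective: reconstruction + beta * KL."""
    recon = ((recon_x - x) ** 2).sum(dim=(1, 2, 3)).mean()  # sum-reduced MSE, batch mean
    kld = (-0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum(dim=(1, 2, 3))).mean()
    return recon + beta * kld, recon, kld

torch.manual_seed(0)
x = torch.rand(2, 4, 32, 32)
recon_x = torch.rand(2, 4, 32, 32)
mu = torch.zeros(2, 4, 32, 32)      # posterior exactly equal to the prior N(0, I)
logvar = torch.zeros(2, 4, 32, 32)

total, recon, kld = beta_vae_loss(x, recon_x, mu, logvar)
assert kld.item() == 0.0            # KL vanishes when the posterior matches the prior
assert total.item() == recon.item()
```

Because the reconstruction term is *summed* over pixels while `beta` scales the KL, switching the reduction from sum to mean silently changes the effective beta by a factor of `C * H * W`, which may explain the sensitivity you observed.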