Variational Autoencoder loss function: MSE vs BCE

Hello guys!

I need your wisdom and intelligence. I'm working with Variational Autoencoders, but I don't understand when I should choose MSE or BCE as the loss function. As far as I understand, I should pick MSE if I believe that the latent space of the embedding is Gaussian, and BCE if it's multinomial. Is that true?

For instance, I am doing some tests with the MNIST dataset. If I pick `nn.MSELoss`, it works terribly. If I pick `nn.BCELoss(reduction='sum')`, it works decently. Does that mean that the latent space distribution is not Gaussian?

Here is my code, if you want to try it; there's a huge difference:

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt

features: int = 16  # size of the latent vector z; the encoder emits 2 * features values (mu and log_var)

class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened inputs.

    Keyword Args:
        input_shape: number of input features (784 for a flattened 28x28 MNIST image).
        mid_dim: width of the single hidden layer in both encoder and decoder.
        latent_dim: optional size of the latent vector z. When omitted, the
            module-level ``features`` constant is used (original behavior).
    """

    def __init__(self, **kwargs):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it the first ``self.encoder1 = ...`` raises AttributeError.
        super().__init__()

        # Looked up lazily so callers passing latent_dim do not need ``features``.
        self.latent_dim = kwargs["latent_dim"] if "latent_dim" in kwargs else features

        # encoder layers: the last layer emits mu and log_var side by side
        self.encoder1 = nn.Linear(in_features=kwargs["input_shape"], out_features=kwargs["mid_dim"])
        self.encoder2 = nn.Linear(in_features=kwargs["mid_dim"], out_features=self.latent_dim * 2)

        # decoder layers
        self.decoder1 = nn.Linear(in_features=self.latent_dim, out_features=kwargs["mid_dim"])
        self.decoder2 = nn.Linear(in_features=kwargs["mid_dim"], out_features=kwargs["input_shape"])

    def reparametrize(self, mu, log_var):
        """Draw z ~ N(mu, exp(log_var)) with the reparameterization trick.

        Args:
            mu: mean of the encoder's latent distribution.
            log_var: log-variance of the encoder's latent distribution, same shape as ``mu``.

        Returns:
            ``mu + eps * std`` with ``eps ~ N(0, I)``. The randomness is isolated
            in ``eps``, so gradients flow back into ``mu`` and ``log_var``.
        """
        # exp(0.5 * log_var) == sqrt(exp(log_var)): converts log-variance to a standard deviation
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # standard-normal noise, same shape as std
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: input batch, flattened to ``(N, input_shape)``.

        Returns:
            Tuple ``(reconstruction, mu, log_var, z)``. ``reconstruction`` has
            the shape of ``x`` with values in [0, 1] (sigmoid output). The other
            three are ``(N, latent_dim)``.
        """
        # encode
        hidden = F.relu(self.encoder1(x))
        stats = self.encoder2(hidden).view(-1, 2, self.latent_dim)

        # first half of the encoder output is the mean, second half the log-variance
        mu = stats[:, 0, :]
        log_var = stats[:, 1, :]

        z = self.reparametrize(mu, log_var)  # one sample from q(z|x)

        # decode
        hidden = F.relu(self.decoder1(z))
        reconstruction = torch.sigmoid(self.decoder2(hidden))
        return reconstruction, mu, log_var, z

model = VAE(input_shape=784, mid_dim=512)

# Reconstruction term of the ELBO. BCE treats every pixel (scaled to [0, 1]) as
# a Bernoulli likelihood; MSE corresponds to a Gaussian likelihood on the pixels.
# The choice is about the decoder's output distribution, not the latent space.
# Whichever you use, SUM it (reduction='sum'), matching the summed KL term in
# total_loss. With nn.MSELoss()'s default reduction='mean', the reconstruction
# error is divided by batch_size * 784, so the KL term dominates and the decoder
# tends to ignore z (typically blurry, average-looking digits).
criterion = nn.BCELoss(reduction='sum')
# criterion = nn.MSELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# ToTensor converts the PIL images to float tensors in [0, 1], which is the
# target range nn.BCELoss requires.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)

# NOTE(review): num_workers > 0 spawns worker processes; on platforms that use
# "spawn" (Windows/macOS) the training code should sit under
# `if __name__ == "__main__":` — confirm for your environment.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=4
)

def total_loss(mu, logvar, mse_loss):
    """Return the (negative) ELBO: KL(q(z|x) || N(0, I)) plus a reconstruction loss.

    Args:
        mu: means of the approximate posterior.
        logvar: log-variances of the approximate posterior, same shape as ``mu``.
        mse_loss: precomputed reconstruction loss. Despite the name, any criterion
            works (the script passes BCE).

    Returns:
        Scalar tensor ``KL + mse_loss``, with the KL summed over all elements.
    """
    # closed-form KL divergence between N(mu, exp(logvar)) and N(0, 1), per element
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)
    return kl_divergence + mse_loss

epochs = 20

for epoch in range(epochs):
    loss = 0
    for batch_features, _ in train_loader:
        # reshape mini-batch data to [N, 784] matrix
        batch_features = batch_features.view(-1, 784)
        # reset the gradients back to zero;
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()
        # compute reconstructions
        outputs, mu, logvar, code = model(batch_features)
        # compute training reconstruction loss (whatever `criterion` is: BCE or summed MSE)
        MSE_loss = criterion(outputs, batch_features)
        Loss = total_loss(mu, logvar, MSE_loss)
        # compute accumulated gradients; without this, optimizer.step() has nothing to apply
        Loss.backward()
        # perform parameter update based on current gradients
        optimizer.step()
        # add the mini-batch training loss to epoch loss
        loss += Loss.item()
    # compute the epoch training loss (mean of the per-batch summed losses)
    loss = loss / len(train_loader)
    # display the epoch training loss
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, loss))

    # Visualize the first element of the last batch. The tensors still carry
    # autograd history, so detach them before converting to NumPy.
    reconstructed = outputs.detach().cpu().view(-1, 28, 28)
    original = batch_features.detach().cpu().view(-1, 28, 28)
    coded = code.detach().cpu().view(code.shape[0], -1, 2)  # (N, latent/2, 2) grid

    # one axes per image, so the three plots do not overwrite each other
    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(reconstructed[0].numpy(), cmap="gray")
    axes[0].set_title("Reconstructed")
    axes[1].imshow(original[0].numpy(), cmap="gray")
    axes[1].set_title("Original")
    axes[2].imshow(coded[0].numpy(), cmap="gray")
    axes[2].set_title("Code")
    plt.suptitle("Reconstructed vs Original vs Code")
    plt.show()

Thank you so much! 🙂