# Variational Autoencoder loss function: MSE vs BCE

Hello guys!

I need your wisdom and intelligence. I’m working with Variational Autoencoders, but I don’t understand when I should choose MSE or BCE as the loss function. As far as I understand, I should pick MSE if I believe that the latent space of the embedding is Gaussian, and BCE if it’s multinomial — is that true?

For instance, I am doing some tests with the MNIST dataset. If I pick nn.MSELoss, it works terribly wrong. If I pick nn.BCELoss(reduction='sum'), it works decently. Does that mean the latent space distribution is not Gaussian?

My code is this, if you want to try it, there’s a huge difference:

```python
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt

features = 16

class VAE(nn.Module):
def __init__(self, **kwargs):
super().__init__()

#encoder layers
self.encoder1 = nn.Linear(in_features=kwargs["input_shape"], out_features=kwargs["mid_dim"])
self.encoder2 = nn.Linear(in_features=kwargs["mid_dim"], out_features=features*2)

#decoder layers
self.decoder1 = nn.Linear(in_features=features, out_features=kwargs["mid_dim"])
self.decoder2 = nn.Linear(in_features=kwargs["mid_dim"], out_features=kwargs["input_shape"])

def reparametrize(self, mu, log_var):

# mu: mean of the encoder's latent space distribution
# log_var: variance from the encoder's latient space distribution

std = torch.exp(0.5*log_var) #standard deviation. 0,5 to have a unit variance
eps = torch.randn_like(std) #same size as std
sample = mu + (eps*std) #we take a value of the distribution of the latent space
return sample

def forward(self, x):
# encode
x = F.relu(self.encoder1(x))
x = self.encoder2(x).view(-1,2,features)

#get mu and log_var
mu = x[:, 0, :] # the first feature values as mean
log_var = x[:, 1, :] # the other feature values as variance

z = self.reparametrize(mu,log_var) #get a sample of the distribution

#decode
x = F.relu(self.decoder1(z))
reconstruction = torch.sigmoid(self.decoder2(x))
return reconstruction, mu, log_var, z

model = VAE(input_shape=784, mid_dim=512)

# NOTE(review): the original paste never defined `optimizer` although the
# training loop calls optimizer.step(); Adam with the default-ish lr is a
# reasonable reconstruction — confirm against the author's actual setup.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

criterion = nn.BCELoss(reduction='sum')# MSE or CrossEntropy?
#criterion = nn.MSELoss()

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# NOTE(review): the MNIST(...) argument lists were lost in the paste; these
# are the conventional arguments — verify root/download against the original.
train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="./data", train=False, transform=transform, download=True
)

# The DataLoader call headers were also lost; the orphaned keyword-argument
# lines in the paste fix the batch sizes and flags used here.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=4
)

def total_loss(mu, logvar, mse_loss):
    """Negative ELBO: KL(q(z|x) || N(0, I)) plus the reconstruction loss.

    mu, logvar -- encoder outputs for the batch
    mse_loss   -- precomputed reconstruction term (BCE or MSE, summed)
    """
    # closed-form KL divergence for a diagonal Gaussian vs the unit Gaussian
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return kld + mse_loss

epochs = 20

for epoch in range(epochs):
    epoch_loss = 0.0

    # NOTE(review): the inner loop over the loader was missing from the paste;
    # batch_features was used undefined. MNIST loaders yield (image, label).
    for batch_features, _ in train_loader:
        # reshape mini-batch data to a [N, 784] matrix
        batch_features = batch_features.view(-1, 784)

        # reset the gradients back to zero —
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()

        # compute reconstructions plus the latent statistics
        outputs, mu, logvar, code = model(batch_features)

        # training reconstruction loss + KL term
        recon_loss = criterion(outputs, batch_features)
        loss_value = total_loss(mu, logvar, recon_loss)

        loss_value.backward()

        # perform parameter update based on current gradients
        optimizer.step()

        # add the mini-batch training loss to the epoch loss
        epoch_loss += loss_value.item()

    # display the epoch training loss
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, epoch_loss))

reconstructed = outputs.view(-1, 1, 28, 28)
original = batch_features.view(-1, 1, 28, 28)
coded = code.view(-1, 1, 8, 2)

# ToPILImage expects a single [C, H, W] image, not a batch — index the first
# element (the original passed the whole 4-D batch, which raises).
img = T.ToPILImage()(reconstructed[0])  # plot the first element of the last batch
img2 = T.ToPILImage()(original[0])
img_code = T.ToPILImage()(coded[0])

plt.subplot(131)
imgplot = plt.imshow(img)
plt.subplot(132)
imgplot2 = plt.imshow(img2)
plt.subplot(133)
imgplot3 = plt.imshow(img_code)
plt.suptitle("Reconstructed vs Original vs Code")
plt.show()
```

Thank you so much! 