Hello guys!
I need your wisdom and intelligence. I’m working with Variational Autoencoders, but I don’t understand when I should choose MSE or BCE as the loss function. As far as I understand, I should pick MSE if I believe that the latent space of the embedding is Gaussian, and BCE if it’s multinomial; is that true?
For instance, I am running some tests with the MNIST dataset. If I pick nn.MSELoss, it works terribly. If I pick nn.BCELoss(reduction='sum'), it works decently. Does that mean that the latent space distribution is not Gaussian?
My code is below, if you want to try it; there’s a huge difference between the two losses:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt
# Default size of the latent code. The encoder emits twice this many values:
# one mean and one log-variance per latent dimension.
features = 16


class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an input vector to the mean and log-variance of an
    element-wise Gaussian over the latent code; the decoder maps a sampled
    code back to the input space and squashes it with a sigmoid.

    Keyword Args:
        input_shape (int): Number of features in each input sample.
        mid_dim (int): Width of the hidden layer in the encoder and decoder.
        latent_dim (int, optional): Size of the latent code. Defaults to the
            module-level ``features`` so existing callers are unaffected.

    Raises:
        KeyError: If ``input_shape`` or ``mid_dim`` is not supplied.
    """

    def __init__(self, **kwargs):
        super().__init__()
        input_dim = kwargs["input_shape"]
        hidden_dim = kwargs["mid_dim"]
        # Kept on the instance so forward() does not depend on a global.
        self.latent_dim = kwargs.get("latent_dim", features)

        # Encoder layers: the last one outputs [mu | log_var] concatenated.
        self.encoder1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.encoder2 = nn.Linear(in_features=hidden_dim, out_features=self.latent_dim * 2)
        # Decoder layers.
        self.decoder1 = nn.Linear(in_features=self.latent_dim, out_features=hidden_dim)
        self.decoder2 = nn.Linear(in_features=hidden_dim, out_features=input_dim)

    def reparametrize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Args:
            mu: Mean of the latent distribution produced by the encoder.
            log_var: Log-variance of the latent distribution.

        Returns:
            A tensor shaped like ``mu`` holding one sample per element.
        """
        # exp(0.5 * log(sigma^2)) == sigma: the 0.5 turns log-variance into
        # log-standard-deviation.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # noise with the same shape as std
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Input batch whose last dimension equals ``input_shape``.

        Returns:
            Tuple ``(reconstruction, mu, log_var, z)``. ``reconstruction`` has
            passed through a sigmoid; the other three have ``latent_dim``
            columns.
        """
        # Encode.
        hidden = F.relu(self.encoder1(x))
        stats = self.encoder2(hidden).view(-1, 2, self.latent_dim)
        # First half of the encoder output is the mean, second half the
        # log-variance.
        mu = stats[:, 0, :]
        log_var = stats[:, 1, :]
        z = self.reparametrize(mu, log_var)
        # Decode.
        hidden = F.relu(self.decoder1(z))
        reconstruction = torch.sigmoid(self.decoder2(hidden))
        return reconstruction, mu, log_var, z
# --- Model, reconstruction criterion and optimiser -------------------------
model = VAE(input_shape=784, mid_dim=512)

# Reconstruction term, summed over every pixel of every sample in the batch.
# The sum matters: the KL term in total_loss() is also a plain sum, so both
# parts of the objective are on the same scale.
criterion = nn.BCELoss(reduction='sum')
# Alternative reconstruction term. nn.MSELoss defaults to reduction='mean',
# which averages over all elements and makes the term far smaller than the
# summed KL; pass reduction='sum' when trying it:
# criterion = nn.MSELoss(reduction='sum')

optimizer = optim.Adam(model.parameters(), lr=0.0001)

# --- Data -------------------------------------------------------------------
# ToTensor turns each image into a float tensor scaled to [0, 1].
transform = T.Compose([T.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets",
    train=True,
    transform=transform,
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets",
    train=False,
    transform=transform,
    download=True,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)
def total_loss(mu, logvar, mse_loss):
    """Return the VAE objective: KL divergence plus the reconstruction term.

    Args:
        mu: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.
        mse_loss: Reconstruction loss already computed by the caller
            (despite the name, any criterion may have produced it).

    Returns:
        ``KL + mse_loss``, where KL is summed over every element of
        ``mu``/``logvar``.
    """
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all elements.
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(per_element)
    return kl_term + mse_loss
epochs = 20
for epoch in range(epochs):
    # Running sum of the per-batch objective for this epoch.
    epoch_loss = 0.0
    for batch_features, _ in train_loader:
        # Flatten each image: reshape the mini-batch to an [N, 784] matrix.
        batch_features = batch_features.view(-1, 784)

        # PyTorch accumulates gradients across backward passes, so clear
        # them before every step.
        optimizer.zero_grad()

        # Forward pass: reconstruction plus the latent statistics and sample.
        outputs, mu, logvar, code = model(batch_features)

        # Reconstruction term from the configured criterion, then add the KL
        # term to obtain the full objective.
        recon_loss = criterion(outputs, batch_features)
        batch_loss = total_loss(mu, logvar, recon_loss)

        # Backward pass and parameter update.
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()

    # Average over the number of mini-batches in the epoch.
    epoch_loss = epoch_loss / len(train_loader)
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs, epoch_loss))

# NOTE(review): the pasted source lost its indentation; plotting is assumed to
# run once after training, on the first element of the last batch - confirm.
# Tensors are detached so plotting never touches the autograd graph.
reconstructed = outputs.detach().view(-1, 28, 28)
original = batch_features.view(-1, 28, 28)
# Two columns, as many rows as needed (8 x 2 for a 16-dimensional code).
coded = code.detach().view(code.size(0), -1, 2)

# The latent code is unbounded, so it is shown as raw floats and imshow
# normalises the colour range itself; converting it to an 8-bit image would
# mangle negative or large values.
plt.subplot(131)
imgplot = plt.imshow(reconstructed[0].numpy())
plt.subplot(132)
imgplot2 = plt.imshow(original[0].numpy())
plt.subplot(133)
imgplot3 = plt.imshow(coded[0].numpy())
plt.suptitle("Reconstructed vs Original vs Code")
plt.show()
Thank you so much!