Hello everyone!
I’m trying to implement a simple VAE by following several tutorials ([1]: https://sannaperzon.medium.com/paper-summary-variational-autoencoders-with-pytorch-implementation-1b4b23b1763a). This is the code I have written:
import torch

class VAE(torch.nn.Module):
    def __init__(self, input_shape, h_dim, z_dim):
        super().__init__()
        # Encoder: from image space to hidden space
        self.img_2hid = torch.nn.Linear(input_shape, h_dim)
        # From the hidden space we move to z space and get mu and log-sigma
        self.fc_mu = torch.nn.Linear(h_dim, z_dim)
        self.fc_log_sigma = torch.nn.Linear(h_dim, z_dim)
        # The decoder goes from z space back to the hidden space
        self.z_2hidden = torch.nn.Linear(z_dim, h_dim)
        # Finally, from the hidden space we go back to the input space
        self.hidden_2img = torch.nn.Linear(h_dim, input_shape)
        self.relu = torch.nn.ReLU()
        # reduction must be passed as a keyword: MSELoss('sum') would set the
        # deprecated size_average flag instead and silently average the loss
        self.rloss = torch.nn.MSELoss(reduction='sum')

    def encode(self, x):
        x = self.relu(self.img_2hid(x))
        # The second head outputs log(sigma) directly; wrapping the linear
        # output in torch.log would give NaNs for negative activations
        return self.fc_mu(x), self.fc_log_sigma(x)

    def decode(self, z):
        z = self.relu(self.z_2hidden(z))
        return self.relu(self.hidden_2img(z))

    def normalize(self, x):
        # Min-max rescale the decoder output to the [0, 255] pixel range
        v_min, v_max = x.min(), x.max()
        new_min, new_max = 0.0, 255.0
        return (x - v_min) / (v_max - v_min) * (new_max - new_min) + new_min

    def forward(self, x):
        original = x.shape
        # Encode x to get the mu and log-sigma parameters of q(z|x);
        # flatten each image but keep the batch dimension
        mu, log_sigma = self.encode(x.reshape(x.shape[0], -1))
        # Sample z from q via the reparameterization trick
        sigma = torch.exp(log_sigma)
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        # Decode and rescale back to the input range and shape
        x_hat = self.normalize(self.decode(z)).reshape(original)
        # Reconstruction loss
        recon_loss = self.rloss(x_hat, x)
        # KL divergence, closed form for a diagonal Gaussian q(z|x) vs. N(0, I)
        kl_div = -0.5 * torch.sum(1 + 2 * log_sigma - mu.pow(2) - sigma.pow(2))
        # Negative ELBO, i.e. the quantity to minimize
        elbo = recon_loss + kl_div
        return elbo, recon_loss, kl_div, x_hat
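For reference, the kl_div line is meant to be the closed-form KL divergence between the diagonal Gaussian approximate posterior q(z|x) = N(mu, diag(sigma^2)) and the standard normal prior p(z) = N(0, I):

$$D_{\mathrm{KL}}\left(q(z\mid x)\,\|\,p(z)\right) = -\frac{1}{2}\sum_{j}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

which is what the code computes, using log sigma_j^2 = 2 * log sigma_j.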
The problem I’m having is that the output of the decoder is extremely noisy. For reference, by training on this image:
[input image]
I get:
[noisy reconstruction]
Any idea why? On the MNIST dataset the VAE works well, but as soon as I pass real images it behaves this way.
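In case it helps to reproduce the issue, here is a minimal sketch of the kind of training loop I’m running (the file name, image size, layer dimensions, optimizer and learning rate below are illustrative placeholders, not my exact values):

import torch
from PIL import Image
from torchvision import transforms

# Illustrative setup: one RGB image resized to 64x64, so input_shape = 3*64*64
img = Image.open("photo.jpg").convert("RGB").resize((64, 64))  # hypothetical file
x = transforms.ToTensor()(img).unsqueeze(0) * 255.0  # batch of one, pixels in [0, 255]

model = VAE(input_shape=3 * 64 * 64, h_dim=256, z_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(1000):
    optimizer.zero_grad()
    elbo, recon_loss, kl_div, x_hat = model(x)
    elbo.backward()  # elbo here is the negative ELBO, i.e. the loss
    optimizer.step()

The reconstruction x_hat is what comes out noisy.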