VAE doesn't work, but AE does

Hello guys,

I’m facing a problem with my VAE and it’s that it doesn’t properly reconstruct the data.
My process is: run UMAP on the original data to see how it is structured, train the VAE, and then run UMAP on the reconstructed data.
The point is that the UMAP plot from the beginning and the UMAP plot from the reconstructed data are not similar at all. In fact, they’re completely different. That’s why I think my VAE is not working.
Besides, if I run just an AE, it works perfectly fine. And even if I remove the KL-divergence term from the loss, the VAE works.

Does anyone know what could be happening here? It’s so strange…

My code is this:

# Latent-space size: the encoder emits features*2 values per sample,
# which VAE.forward splits into a mean vector and a log-variance vector.
features = 16

class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps a flat input to ``2 * latent_dim`` values, read as the
    mean and log-variance of a diagonal Gaussian over the latent space. The
    decoder maps a latent sample back to the input size through a sigmoid.

    Keyword Args:
        input_shape: Number of features of each (flattened) input sample.
        mid_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Optional size of the latent space. When omitted, the
            module-level ``features`` constant is used.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # Only fall back to the module-level constant when no size was given,
        # so the model can be built with any latent width.
        if "latent_dim" in kwargs:
            self.latent_dim = kwargs["latent_dim"]
        else:
            self.latent_dim = features

        # encoder layers: the last one emits mean and log-variance side by side
        self.encoder1 = nn.Linear(in_features=kwargs["input_shape"], out_features=kwargs["mid_dim"])
        self.encoder2 = nn.Linear(in_features=kwargs["mid_dim"], out_features=self.latent_dim * 2)

        # decoder layers
        self.decoder1 = nn.Linear(in_features=self.latent_dim, out_features=kwargs["mid_dim"])
        self.decoder2 = nn.Linear(in_features=kwargs["mid_dim"], out_features=kwargs["input_shape"])

    def reparametrize(self, mu, log_var):
        """Return a latent sample for the given posterior parameters.

        Args:
            mu: Mean of the encoder's latent distribution.
            log_var: Log-variance of the encoder's latent distribution.

        Returns:
            ``mu + eps * std`` with ``eps`` drawn from a standard normal while
            the module is in training mode; ``mu`` itself in evaluation mode.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * log_var)  # exp(0.5 * log(var)) == sqrt(var)
        eps = torch.randn_like(std)  # noise with the same shape as std
        return mu + (eps * std)

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: Batch of flattened inputs whose last dimension is ``input_shape``.

        Returns:
            Tuple ``(reconstruction, mu, log_var, z)``.
        """
        # encode
        hidden = F.relu(self.encoder1(x))
        stats = self.encoder2(hidden).view(-1, 2, self.latent_dim)

        # first half of the encoder output is the mean, second half the log-variance
        mu = stats[:, 0, :]
        log_var = stats[:, 1, :]

        z = self.reparametrize(mu, log_var)

        # decode
        hidden = F.relu(self.decoder1(z))
        # NOTE(review): the sigmoid bounds reconstructions to (0, 1); this
        # presumably expects inputs scaled to [0, 1] -- confirm against the data.
        reconstruction = torch.sigmoid(self.decoder2(hidden))
        return reconstruction, mu, log_var, z

## Loss function

def final_loss(mu, logvar, reconstruction_loss, beta=1.0):
    """Combine the reconstruction loss with the KL divergence to a unit Gaussian.

    Args:
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.
        reconstruction_loss: Already-computed reconstruction term.
        beta: Weight applied to the KL term. The default of 1.0 gives the
            plain sum of both terms; a smaller value keeps the KL term from
            dominating when the reconstruction term is on a smaller scale.

    Returns:
        ``beta * KL + reconstruction_loss``.
    """
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over every element of
    # the inputs (Appendix B of the VAE paper).
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return beta * kl_divergence + reconstruction_loss

## And the training script

#define a function to train the data
def fit(model, dataloader):
    """Train ``model`` for one pass over ``dataloader`` and return the average loss.

    Relies on the module-level ``optimizer``, ``criterion`` and ``final_loss``.

    Args:
        model: Network whose forward pass returns
            ``(reconstruction, mu, logvar, z)``.
        dataloader: Loader yielding ``(data, label)`` batches; labels are ignored.

    Returns:
        The accumulated batch losses divided by ``len(dataloader.dataset)``.
    """
    # NOTE(review): final_loss sums the KL term over every element of the
    # batch, so ``criterion`` presumably needs reduction="sum" for both terms
    # to be on the same scale -- confirm how criterion is constructed.
    model.train()
    running_loss = 0.0
    # len(dataloader) is the real number of batches of *this* loader, including
    # a partial last batch.
    for data, _ in tqdm(dataloader, total=len(dataloader)):
        data = data.view(data.size(0), -1)  # flatten each sample to a vector
        optimizer.zero_grad()  # reset the gradients back to zero
        reconstruction, mu, logvar, _ = model(data)  # compute reconstructions
        reconstruction_loss = criterion(reconstruction, data)
        # full objective: reconstruction + KL divergence
        loss = final_loss(mu, logvar, reconstruction_loss)
        running_loss += loss.item()
        loss.backward()  # compute gradients for this batch
        optimizer.step()  # update the weights
    return running_loss / len(dataloader.dataset)  # average loss per sample

Thanks!!