Decoder not learning / no backprop occurring in VAE model

Hi Team,
I have designed a VAE to run on a dataset of 50k+ stellar spectra.

I have been trying to adjust the model and train it for the past week, but I have noticed that while the KL loss improves, the reconstruction loss hovers around the same value and never improves.

When I use only the decoder part of the VAE, it always produces the same output no matter which latent vector I give it, and every training sample also gets the same output from the full VAE.

I have noticed that gradients are computed and weights are updated for everything in the encoder section of the VAE, but all gradients in the decoder are None and none of its weights change.

I have not found any other posts about this happening to others, so I was wondering if I could get some help. I am unsure why the decoder gradients are not being computed, and I believe this is why the VAE is not learning.

My VAE model is below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        input_dim = 400
        z_dim = 5

        #encoder defs
        self.conv1_e = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=8, stride=2, padding=1)
        self.conv2_e = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=4, stride=2, padding=1) 
        self.linear1_e = nn.Linear(in_features=99, out_features=50)
        self.linear2_e = nn.Linear(in_features=50, out_features=25)
        self.z_mu_e = nn.Linear(in_features=25, out_features=z_dim)
        self.z_var_e = nn.Linear(in_features=25, out_features=z_dim)

        #decoder defs
        self.fc_d = nn.Linear(in_features=z_dim, out_features=25)
        self.linear1_d = nn.Linear(in_features=25, out_features=50)
        self.linear2_d = nn.Linear(in_features=50, out_features=99)
        self.conv1_d = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=4, stride=2, padding=1)
        self.conv2_d = nn.ConvTranspose1d(in_channels=1, out_channels=1, kernel_size=8, stride=2, padding=1)

    def encode(self, x):
        x = F.relu(self.conv1_e(x))    # (N, 1, 400) -> (N, 1, 198)
        x = F.relu(self.conv2_e(x))    # (N, 1, 198) -> (N, 1, 99)
        x = F.relu(self.linear1_e(x))  # Linear acts on the last dim: (N, 1, 50)
        x = F.relu(self.linear2_e(x))  # (N, 1, 25)
        z_mu = self.z_mu_e(x)          # (N, 1, z_dim)
        z_var = self.z_var_e(x)        # log-variance, (N, 1, z_dim)
        return z_mu, z_var

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        x = self.fc_d(z)
        x = F.relu(self.linear1_d(x))
        x = F.relu(self.linear2_d(x))
        x = x.view(x.size()[0], 1, 99)
        x = F.relu(self.conv1_d(x))    # (N, 1, 99) -> (N, 1, 198)
        x = F.relu(self.conv2_d(x))    # (N, 1, 198) -> (N, 1, 400)
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)

        return self.decode(z), mu, logvar
```

And here is the training code:

```python
import numpy as np
import torch
import torch.optim as optim

lr = 1e-6

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# define vae
model = ml.VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

def objective_loss(y1, y2, ivar, log_z_var, z_mu):

    diff = (y1 - y2)**2 * ivar.numpy()
    recon_loss = np.mean(np.sum(diff, axis=1) / INPUT_DIM)
    kl_loss = 0.5 * torch.sum(torch.exp(log_z_var) + z_mu**2 - 1.0 - log_z_var)
    return recon_loss + kl_loss

def train():
    train_loss = 0

    for i, (wl, fl, iv) in enumerate(train_iterator):

        # fl = …  (truncated in the original post)
        fl = fl.reshape(fl.size()[0], 1, 400)
        fl = fl.float()

        fl_sample, z_mu, log_z_var = model(fl)
        y1 = fl_sample.view(fl.size()[0], 400).detach().numpy()
        y2 = fl.view(fl.size()[0], 400).detach().numpy()

        loss = objective_loss(y1, y2, iv, log_z_var, z_mu)

        g = model.linear2_d.weight
        # w = …  (truncated in the original post)

        train_loss += loss.item()

    return train_loss
```

Does anyone have any suggestions as to what I could check or try to fix this problem? Or can anyone see straight away in my code why this would happen?
Thank you very much :slight_smile:

Hey, Alex. Hope you're doing well.
The `.detach().numpy()` lines in your training loop got my attention. I'm by no means an expert, just a PyTorch beginner, but since I've spent much of my PyTorch time with VAEs, I'll give it a shot:

  1. The way I see it, you are preventing your decoder from learning to reconstruct anything. You transform y1 and y2 by detaching them (which cuts them out of the gradient graph) and then numpy-ing them (yes, I want to make this word trendy). So there is definitely no reconstruction signal in your diff variable: diff carries no gradients, and the only term autograd can still see is the KL term, which depends only on the encoder outputs mu and logvar. That is why the decoder gradients are None. Without the reconstruction term, the VAE only makes an effort to satisfy the KL term, i.e. it just compresses information, with no generative aspect to the algorithm.

  2. Question: what are fl and fl_sample? And why do you think it is necessary to detach() and numpy them? My suggestion is to keep diff, y1 and y2 as tensors.
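
To see point 1 in isolation, here is a minimal sketch with two hypothetical stand-in layers (`enc` and `dec`, not the VAE from the question) and a KL-like term that depends only on the encoder output. When the reconstruction term is computed in NumPy on detached copies, it is just a constant to autograd, so after `backward()` the decoder's `.grad` stays `None`; computing the same term with torch ops restores the decoder gradient.

```python
import numpy as np
import torch
import torch.nn as nn


def decoder_gets_grad(keep_graph: bool) -> tuple:
    """Return (encoder_has_grad, decoder_has_grad) after one backward pass."""
    torch.manual_seed(0)
    enc = nn.Linear(4, 2)  # hypothetical stand-in for the encoder (outputs "mu")
    dec = nn.Linear(2, 4)  # hypothetical stand-in for the decoder
    x = torch.randn(8, 4)

    mu = enc(x)
    recon = dec(mu)
    kl_like = 0.5 * torch.sum(mu ** 2)  # KL-style term: depends only on the encoder

    if keep_graph:
        # Reconstruction term stays a tensor, so autograd can reach the decoder.
        recon_term = torch.mean((recon - x) ** 2)
    else:
        # What the question does: detach + NumPy turns the term into a plain number.
        recon_term = float(np.mean((recon.detach().numpy() - x.numpy()) ** 2))

    loss = kl_like + recon_term
    loss.backward()
    return enc.weight.grad is not None, dec.weight.grad is not None


print(decoder_gets_grad(False))  # (True, False): decoder never receives a gradient
print(decoder_gets_grad(True))   # (True, True)
```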

My best regards,

Thank you so much for your reply and your help,

Answer 1:
I am using my own reconstruction loss calculation because I was unsure how to correctly incorporate inverse-variance weighting, which is what ivar is. They are also all detached because I did not know how to perform the MSE*IVAR calculation in tensor form. How would I perform that calculation without detaching the gradients?

Answer 2:
fl is flux which is my input data and fl_sample is the reconstruction from the VAE.

(I think you are correct that the issues come from things not being tensors, but I'm still unsure how to do what I want with them staying as tensors.)

Thank you!

PyTorch's most basic working element is the tensor. When you load your dataset through enumerate(train_iterator), you are already dealing with tensors. Basically, all you have to do is not convert them to NumPy (remove the .detach().numpy() calls).

Then the underlying problem is how to deal with a custom loss. The simplest way is to write your own inverse-variance function and call it inside your code. You just have to write it with standard torch operations rather than NumPy operations. Something like this:

```python
import torch

def inv_var(y1, y2):
    # inverse-variance weighted combination of y1 and y2, using only torch ops
    y_hat = (y1 / y1.var() + y2 / y2.var()) / (1 / y1.var() + 1 / y2.var())
    return y_hat
```

If you later need to write more complex loss functions, I suggest: … However, in this case I don't see a need for it.
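
For the original question (an inverse-variance-weighted MSE without leaving autograd), here is a sketch of `objective_loss` rewritten purely with torch operations. It assumes `recon`, `target` and `ivar` are tensors of shape `(batch, 400)` on the same device; `INPUT_DIM = 400` matches the spectrum length in the post.

```python
import torch

INPUT_DIM = 400  # spectrum length used in the question


def objective_loss(recon, target, ivar, log_z_var, z_mu):
    """Inverse-variance-weighted reconstruction loss plus KL, using only torch ops.

    recon, target, ivar: tensors of shape (batch, INPUT_DIM)
    log_z_var, z_mu:     encoder outputs (summed over all elements for the KL)
    """
    # Squared residuals weighted by the inverse variance of each pixel (chi-square style).
    diff = (recon - target) ** 2 * ivar
    # Sum over the spectrum, normalise by its length, then average over the batch.
    recon_loss = torch.mean(torch.sum(diff, dim=1) / INPUT_DIM)
    # Analytic KL(q(z|x) || N(0, I)), summed over batch and latent dims as in the post.
    kl_loss = 0.5 * torch.sum(torch.exp(log_z_var) + z_mu ** 2 - 1.0 - log_z_var)
    return recon_loss + kl_loss
```

In the loop from the question this would be called as `objective_loss(fl_sample.view(-1, 400), fl.view(-1, 400), iv.float().to(device), log_z_var, z_mu)`, followed by `optimizer.zero_grad()`, `loss.backward()` and `optimizer.step()`. As in the post, the KL term is summed over the batch while the reconstruction term is averaged, so the KL term effectively grows with batch size; making the two reductions consistent may be worth trying.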


That fixed the problem, thank you! I have gradients now, but they are very small, so on to the next problem haha.

Thank you for your help :slight_smile:
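
Since the gradients are now present but small, one way to see where they shrink is to log the gradient norm of each parameter after `loss.backward()` and before `optimizer.step()`. Below is a minimal sketch; the helper name `grad_report` is hypothetical. `None` entries mean no gradient reached that parameter, which is exactly the symptom described in the question. Also worth noting: the post uses `lr=1e-6` for Adam, while `torch.optim.Adam` defaults to `1e-3`, and with Adam the per-step update is roughly bounded by the learning rate regardless of the gradient scale.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def grad_report(model: nn.Module) -> dict:
    """Map each parameter name to the L2 norm of its gradient, or None if it has none."""
    report = {}
    for name, param in model.named_parameters():
        report[name] = None if param.grad is None else param.grad.norm().item()
    return report
```

For example, `for name, norm in grad_report(model).items(): print(name, norm)` right after `loss.backward()` shows at a glance whether the decoder layers (`fc_d`, `linear1_d`, ...) receive much smaller gradients than the encoder.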


Good to hear that fixed it.
One thing that helps is to plot the evolution of the reconstruction loss and the KL loss to get a better assessment of what's going on.
I have also been through some issues training VAEs, as mentioned, and one tricky problem is called "posterior collapse". Basically, your decoder ignores the latent codes from the encoder and tries to model the data on its own ("auto-decoding").
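
For the plotting suggestion, here is a small sketch of a hypothetical `LossHistory` helper. It assumes the loss function is changed to return the reconstruction and KL terms separately (rather than only their sum), so each can be logged per batch and averaged per epoch. Matplotlib is only imported inside `plot()`.

```python
from collections import defaultdict


class LossHistory:
    """Accumulates per-batch loss components and stores their per-epoch means."""

    def __init__(self):
        self.epochs = defaultdict(list)  # component name -> list of epoch means
        self._batch = defaultdict(list)  # component name -> values in the current epoch

    def add_batch(self, **terms):
        for name, value in terms.items():
            # float() accepts Python numbers and 0-d tensors and drops the autograd graph
            self._batch[name].append(float(value))

    def end_epoch(self):
        for name, values in self._batch.items():
            self.epochs[name].append(sum(values) / len(values))
        self._batch.clear()

    def plot(self):
        import matplotlib.pyplot as plt  # only needed when plotting
        for name, values in self.epochs.items():
            plt.plot(values, label=name)
        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.legend()
        plt.show()
```

Usage would be `history.add_batch(recon=recon_loss, kl=kl_loss)` inside the batch loop and `history.end_epoch()` after it. A KL term that drops towards zero while the reconstruction loss plateaus is one typical picture of posterior collapse.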

This issue is well documented in Fixing a Broken ELBO. I also like these two papers for understanding the whole idea better: Deep Variational Information Bottleneck and Learning Sparse Latent Representations with the Deep Copula Information Bottleneck. Googling "VAE posterior collapse" may also point you to other interesting material.
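
One way to check for posterior collapse directly is to look at the KL divergence per latent dimension: dimensions whose average KL is close to zero carry essentially no information about the input. Below is a sketch, assuming `z_mu` and `log_z_var` come from the encoder in the question (shape `(N, 1, z_dim)`); the `0.01` nats cutoff is an arbitrary assumption, not a value taken from the papers above.

```python
import torch


def kl_per_dimension(z_mu, log_z_var):
    """Mean KL(q(z_j|x) || N(0, 1)) for each latent dimension j, averaged over the batch."""
    kl = 0.5 * (torch.exp(log_z_var) + z_mu ** 2 - 1.0 - log_z_var)
    # Flatten everything except the latent dimension (the encoder outputs (N, 1, z_dim)).
    kl = kl.reshape(-1, kl.shape[-1])
    return kl.mean(dim=0)


def active_dims(z_mu, log_z_var, threshold=0.01):
    """Indices of latent dims whose average KL exceeds `threshold` (an assumed cutoff)."""
    kl = kl_per_dimension(z_mu, log_z_var)
    return [j for j, v in enumerate(kl.tolist()) if v > threshold]
```

Run it under `torch.no_grad()` on a batch of encoder outputs; an empty list means every latent dimension is being ignored.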

I would suggest the following:

  1. Tweak hyperparameters (learning rate, network architecture) and see how the losses evolve. Maybe the architecture is not powerful enough and the problem is not posterior collapse at all. If it is, move to point 2 below.
  2. Once you understand posterior collapse, tweak the losses as described in the papers. For example, one option is to add a weighting term beta to the KL part.
  3. If you have tried all of these things, the reconstruction loss still looks odd, and you still suspect posterior collapse, you can replace the KL term in your loss with the Maximum Mean Discrepancy (MMD). It has been shown to reduce (or eliminate?) this issue. A proper reference is InfoVAE: Balancing Learning and Inference in Variational Autoencoders, and there is a tutorial implementing it here: A Tutorial on Information Maximizing Variational Autoencoders (InfoVAE).
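
For point 2, here is a sketch of a beta-weighted objective (reconstruction + beta · KL) with a linear warm-up of beta. Annealing beta from 0 up to 1 early in training (KL annealing) is one commonly used remedy for posterior collapse; a beta above 1, as in beta-VAE, pushes the other way and strengthens the KL pressure. `warmup_steps` and `beta_max` are placeholder values to tune, not recommendations.

```python
def beta_schedule(step, warmup_steps=10_000, beta_max=1.0):
    """Linear KL warm-up: beta rises from 0 to beta_max over warmup_steps, then stays."""
    if warmup_steps <= 0:
        return beta_max
    return beta_max * min(1.0, step / warmup_steps)


def beta_vae_loss(recon_loss, kl_loss, beta):
    """Weighted ELBO-style objective: reconstruction + beta * KL."""
    return recon_loss + beta * kl_loss
```

Here `step` would be a global batch counter incremented once per optimizer step, with the loss computed as `beta_vae_loss(recon_loss, kl_loss, beta_schedule(step))`.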

Following the steps above may help.
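
For point 3, here is a sketch of the MMD term used in the InfoVAE/MMD-VAE setup: the KL term is replaced by an estimate of the Maximum Mean Discrepancy between latent samples from the encoder and samples from the N(0, I) prior. It uses a Gaussian kernel whose bandwidth is set to the latent dimension, which is an assumed heuristic; the tutorial linked above may differ in details.

```python
import torch


def gaussian_kernel(a, b):
    """Kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / dim); bandwidth = dim (assumed)."""
    dim = a.shape[1]
    # Pairwise squared distances, computed explicitly so gradients stay finite at zero distance.
    sq_dist = ((a.unsqueeze(1) - b.unsqueeze(0)) ** 2).sum(dim=2)
    return torch.exp(-sq_dist / dim)


def mmd(z_q, z_p):
    """Biased estimate of squared MMD between encoder samples z_q and prior samples z_p."""
    k_qq = gaussian_kernel(z_q, z_q).mean()
    k_pp = gaussian_kernel(z_p, z_p).mean()
    k_qp = gaussian_kernel(z_q, z_p).mean()
    return k_qq + k_pp - 2.0 * k_qp
```

With the model from the question, one way to use it would be `z = model.reparameterize(mu, logvar).reshape(-1, 5)` and `loss = recon_loss + lam * mmd(z, torch.randn_like(z))`, where `lam` is a hypothetical weight to tune.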