Unable to train VAE Decoder

Hello Everyone,
I’m hesitant to post this here, as I’m not sure whether this forum is open to debugging questions, but I’m willing to take my chances.

I have followed this tutorial right here to build my own VAE, which I train on a custom dataset containing the same type of images as the MNIST set used in the tutorial (28x28, black and white).

However, I must have done something wrong while implementing it, since my model won’t learn anything. It just outputs the same specific pattern for every input image, as you can see below, even though the input images differ considerably.

As training continues, the problem persists: the model commits even more strongly to this pattern and simply disregards the input.

I’ve included my code below. After seeing people with similar issues, I verified a couple of things about my implementation. I’m fairly sure all my gradients are updating, and I also checked where this pattern comes from: the encoder seems to be doing its job well, i.e. it outputs different encodings for different images, but for some reason the decoder maps these different encodings to very similar output images.
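
For reference, this is roughly the kind of check I ran (an illustrative sketch, not my exact code; vae and batch stand in for my model and one batch of input images):

# Illustrative sketch of the two checks described above (vae and batch are placeholders).
# 1) After loss.backward(), confirm every parameter actually receives a gradient:
for name, param in vae.named_parameters():
    grad_norm = param.grad.norm().item() if param.grad is not None else 0.0
    print(f"{name}: grad norm = {grad_norm:.3e}")

# 2) Confirm the encoder produces different codes for different images:
with torch.no_grad():
    z = vae.encoder(batch)   # one batch of distinct input images
    print(z.std(dim=0))      # spread of each latent dimension across the batch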

Any suggestion would be much appreciated, and I am happy to supply any debugging information you might need. Thank you in advance!
Encoder

import torch
import torch.nn as nn

class VariationalEncoder(nn.Module):
    def __init__(self, latent_dimensions):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16,32,3, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(start_dim=1),
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(),
        )

        # Two heads on top of the shared trunk: linear1 predicts the mean of
        # q(z|x), linear2 predicts the log of its standard deviation.
        self.linear1 = nn.Linear(128, latent_dimensions)
        self.linear2 = nn.Linear(128, latent_dimensions)

        # Standard normal used for the reparameterization trick
        self.N = torch.distributions.Normal(0, 1)
        self.KLDivergence = 0

    def forward(self, x):
        out = self.model(x)

        mean = self.linear1(out)
        stddev = torch.exp(self.linear2(out))  # exp keeps the predicted std positive

        # Reparameterization trick: z = mean + std * eps, with eps ~ N(0, 1)
        z = mean + stddev * self.N.sample(stddev.shape)

        # KL divergence between q(z|x) and the standard normal prior,
        # stored so the training loop can add it to the reconstruction loss
        self.KLDivergence = (stddev**2 + mean**2 - torch.log(stddev) - 1/2).sum()
        return z
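
In case it helps: the 3 * 3 * 32 in the Linear layer comes from the spatial size after the three stride-2 convolutions (28 -> 14 -> 7 -> 3 for my 28x28 inputs). A quick standalone check of that, just for illustration:

# Trace the spatial size through the three convolutions for a 28x28 input.
x = torch.zeros(1, 1, 28, 28)
features = nn.Sequential(
    nn.Conv2d(1, 8, 3, stride=2, padding=1),
    nn.Conv2d(8, 16, 3, stride=2, padding=1),
    nn.Conv2d(16, 32, 3, stride=2, padding=0),
)(x)
print(features.shape)  # torch.Size([1, 32, 3, 3]), hence 3 * 3 * 32 = 288 flattened features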

Decoder

class Decoder(nn.Module):
    def __init__(self, latent_dimensions):
        super().__init__()

        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dimensions, 128),
            nn.ReLU(True),
            nn.Linear(128, 3 * 3 * 32),
            nn.ReLU(True),
        )

        # Reshape the flat features back into 32 feature maps of size 3x3
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),
        )

    def forward(self, x):
        x = self.decoder_lin(x)
        x = self.unflatten(x)
        x = self.decoder_conv(x)
        x = torch.sigmoid(x)  # squash outputs to [0, 1] to match the input image range

        return x

Parent Class

class VariationalAutoEncoder(nn.Module):
    def __init__(self, latent_dimensions):
        super().__init__()
        self.encoder = VariationalEncoder(latent_dimensions)
        self.decoder = Decoder(latent_dimensions)

        # set default floating point data type to float64

        torch.set_default_dtype(torch.float64)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)
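
As a quick smoke test (illustrative only, not part of my training code), a 28x28 input should round-trip to 28x28 with values in [0, 1] because of the final sigmoid:

# Smoke test: forward a random 28x28 batch through the full VAE.
vae = VariationalAutoEncoder(latent_dimensions=3)
# .float() because the constructor switches the default dtype to float64
# after the float32 layer weights have already been created.
x = torch.rand(4, 1, 28, 28).float()
x_hat = vae(x)
print(x_hat.shape)                              # expected: torch.Size([4, 1, 28, 28])
print(x_hat.min().item(), x_hat.max().item())   # expected: both within [0, 1]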

Training Loop

def train_epoch(vae, dataloader, optimizer):
    # Set train mode

    vae.train()
    vae.encoder.train()
    vae.decoder.train()
    train_loss = 0.0

    # set_trace()

    for x in tqdm(dataloader):
        # Get model output

        optimizer.zero_grad()
        # x = torch.from_numpy(x)
        x_hat = vae(x)

        # Evaluate loss
        loss = torch.sum((x - x_hat) ** 2) + vae.encoder.KLDivergence

        # Backward pass
        loss.backward()
        optimizer.step()

        # Print batch loss
        # print(f"\t partial training loss (single batch): {loss.item()} ")
        train_loss += loss.item()

    return train_loss / len(dataloader.dataset)

Thank you for your time!

A good first debugging step would be to try to overfit a small subset of the dataset (e.g. use only 10 samples) and check whether your model is able to learn it by playing around with some hyperparameters.
If that’s not possible, your training code might have other bugs which break the training.
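
For example, something along these lines (just a sketch; dataset is your dataset object):

# Restrict training to 10 samples and check whether the model can overfit them.
from torch.utils.data import Subset, DataLoader

small_dataset = Subset(dataset, list(range(10)))
small_loader = DataLoader(small_dataset, batch_size=10, shuffle=True)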

Thank you for your quick response! I tried what you recommended and trained on only 10 samples and ran 1000 epochs to make sure that I overfitted. Here’s what the model outputted for those 10 samples.

I just realized I didn’t include some of the information you mentioned, so I wanted to add the rest of my training code.

def train():
    # Set rand number seed for reproducible results

    torch.manual_seed(0)
    dataset = PopulationOneLineDataset("../pop_data")

    transform = transforms.Compose([transforms.ToTensor()])

    dataset.transform = transform

    m = len(dataset)

    # train_data, test_dataset = random_split(dataset, [m - int(m * 0.2), int(m * 0.2)])
    batch_size = 256

    train_loader = DataLoader(dataset, batch_size=batch_size,)
    # eval_loader = DataLoader(test_dataset, batch_size=batch_size,shuffle=True)

    # torch.manual_seed(0)
    d = 3
    vae = VariationalAutoEncoder(latent_dimensions=d)

    learning_rate = 1e-4

    optim = torch.optim.Adam(vae.parameters(), lr=learning_rate)

    training_loss = []

    num_epochs = 1000

    for epoch in tqdm(range(num_epochs)):
        train_loss = train_epoch(vae, train_loader, optim)
        training_loss.append(train_loss)
        # eval_loss = test_epoch(vae, eval_loader)
        # print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs,train_loss,eval_loss))

    plot_ae_outputs(vae.encoder, vae.decoder, dataset,n=10)
    torch.save(vae, "vae.model")

    plt.plot(range(len(training_loss)),training_loss)
    plt.show()

Using your code I can properly overfit a single sample of MNIST:

import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

device = "cpu"
model = VariationalAutoEncoder(64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


dataset = datasets.MNIST(root="~/python/data", download=False, transform=transforms.ToTensor())
x, y = dataset[0]
x = x.to(device)
x = x.unsqueeze(0)

for epoch in range(1000):
    optimizer.zero_grad()
    output = model(x)
    loss = torch.sum((x - output) ** 2) + model.encoder.KLDivergence
    loss.backward()
    optimizer.step()
    
    if epoch % 10 == 0:
        print(f"epoch: {epoch}, loss: {loss.item():.3}")
        plt.imshow(output[0, 0].detach().cpu().numpy())

Resulting in:
[output image at the start of training]
to
[output image after overfitting]

The KLDivergence part of the loss seems to be stuck, but I don’t know whether that could be related to your issue of not being able to overfit a simple test case.
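
If you want to check that on your side, it can help to log the two loss terms separately instead of only their sum, e.g. (just a sketch of the loss part of your train_epoch):

# Track the reconstruction and KL terms separately instead of only their sum.
recon_loss = torch.sum((x - x_hat) ** 2)
kl_loss = vae.encoder.KLDivergence
loss = recon_loss + kl_loss
print(f"recon: {recon_loss.item():.3f}, kl: {kl_loss.item():.3f}")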

I see, this makes me feel like something might be wrong with my dataset. I will look into it and report back here if I find something. Thank you so much for your time and efforts!

Hi Deniz,

I’m not sure if this is your issue, but a phenomenon called ‘posterior collapse’ can happen with the original VAE. The ELBO objective amounts to -D[q(x, z) || p(x, z)] (see Eq. 2 of the InfoVAE paper, in the section “Equivalent Forms of the ELBO objective”). What can then happen during training is that the VAE enters a very bad local minimum and cannot recover from it. This local minimum is characterized by the decoder ignoring z (the gradients w.r.t. z go to zero) and simply trying to fit the distribution of the dataset unconditionally.

This phenomenon is not well understood (at least by me!), but it clearly depends on the structure of the dataset, so it is entirely possible that it happens with your dataset but not with MNIST.
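
One mitigation that is often suggested for posterior collapse (no guarantee it applies to your case) is to warm up the KL term, i.e. multiply it by a weight that is annealed from 0 to 1 over the first epochs, roughly like this:

# Sketch of a KL warm-up schedule; warmup_epochs is a hyperparameter you would choose.
beta = min(1.0, epoch / warmup_epochs)
loss = torch.sum((x - x_hat) ** 2) + beta * vae.encoder.KLDivergence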

But did you get it to work on the MNIST dataset, just to make sure you implemented everything correctly according to the tutorial?

Hi Henry,
Thanks for taking the time to answer my question. I will look into the posterior collapse phenomenon, although I haven’t worked on that dataset for a while. I was able to get it working on MNIST with the exact same code (only the data loading part was different).

Best,
Deniz