No sample variety on MNIST for pl_bolts.models.autoencoders.VAE

I hope this is OK to post here, as it’s not directly related to PyTorch. I’ll delete it if it doesn’t fit.
I am trying to train a variational autoencoder on small (64x64) grayscale patches. Ultimately, the downstream task is classification, but since most of my data is not labelled, semi-supervised learning seems like an interesting solution.
I am using PyTorch Lightning, but training a VAE on my images leads to absolutely no sample variety. I have the following code:

```python
import os
from torch.nn import Conv2d
from pytorch_lightning import Trainer
from pl_bolts.models.autoencoders import VAE
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.callbacks.variational import LatentDimInterpolator as LDI

# interpolator callback: decodes a grid of latent codes every epoch
interpolator = LDI(interpolate_epoch_interval=1)

# training
dataset = MNISTDataModule(num_workers=8, batch_size=128)
model = VAE(input_height=28, enc_type="resnet18", kl_coeff=1)

# changing first and last convolution to deal with 28x28 grayscale images
# encoder stem: 1 input channel instead of 3
model.encoder.conv1 = Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
# decoder output: 1 channel; padding=3 makes the output 4 pixels larger than its input
model.decoder.conv1 = Conv2d(64 * model.decoder.expansion, 1, kernel_size=3, stride=1, padding=3, bias=False)
print(model)

# fit model
trainer = Trainer(gpus=1, max_epochs=50, callbacks=[interpolator], default_root_dir=os.getcwd())
trainer.fit(model, dataset)
# save_checkpoint expects a file path, not a directory
trainer.save_checkpoint(os.path.join(os.getcwd(), "vae.ckpt"))
```
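The decoder convolution above uses padding=3 instead of the usual padding=1, presumably to bring the output back to 28x28. A minimal pure-Python sketch of the nn.Conv2d output-size formula shows the arithmetic. The 24x24 decoder feature map is an assumption about the pl_bolts ResNetDecoder with input_height=28 (resize to input_height // 8, then three x2 upsampling stages), and conv2d_out is a hypothetical helper.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of nn.Conv2d with dilation=1: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# Encoder stem replaced in the post: 3x3 conv, stride 1, padding 1 keeps 28x28
encoder_out = conv2d_out(28, kernel=3, stride=1, padding=1)

# Assumed ResNetDecoder path for input_height=28 (first_conv=False, maxpool1=False):
# resize to 28 // 8 = 3x3, then three stages that each upsample by 2 -> 24x24
decoder_map = (28 // 8) * 2 ** 3

last_conv_pad1 = conv2d_out(decoder_map, kernel=3, stride=1, padding=1)  # assumed library default
last_conv_pad3 = conv2d_out(decoder_map, kernel=3, stride=1, padding=3)  # padding used in the post

print(encoder_out, decoder_map, last_conv_pad1, last_conv_pad3)  # 28 24 24 28
```

One side effect worth knowing: with kernel_size=3 and padding=3, the outermost ring of output pixels is computed from zero padding only, so with bias=False it is always exactly 0.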

I only modified the first convolution of the ResNet encoder and the last convolution of the decoder, because the images have only one channel. I also made some small modifications to the callback code: I added a save directory so I can see the images, as I’m not using TensorBoard, and I changed a self.module call into self.module.decoder(), as it was not running before.
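The result images come from LatentDimInterpolator. A minimal NumPy sketch of the latent codes it decodes, assuming it behaves as its pl_bolts docstring describes (all latent dimensions set to zero, only the first two stepped through a grid) with assumed defaults of -5 to 5 in 11 steps. latent_grid and n_varying_dims are hypothetical helpers, and latent_dim=256 is assumed to be the pl_bolts VAE default, since the script above does not set it.

```python
import numpy as np


def latent_grid(latent_dim: int, range_start: float = -5.0, range_end: float = 5.0,
                steps: int = 11) -> np.ndarray:
    """Latent codes with all dims at 0 except the first two, which sweep a grid."""
    codes = []
    for z1 in np.linspace(range_start, range_end, steps):
        for z2 in np.linspace(range_start, range_end, steps):
            z = np.zeros(latent_dim)
            z[0], z[1] = z1, z2
            codes.append(z)
    return np.stack(codes)  # shape (steps * steps, latent_dim)


def n_varying_dims(codes: np.ndarray) -> int:
    """Number of latent dimensions that actually change across the grid."""
    return int(np.count_nonzero(codes.std(axis=0)))


grid_resnet = latent_grid(latent_dim=256)  # assumed pl_bolts default latent_dim
grid_fc = latent_grid(latent_dim=2)        # latent_dim=2 in the fully connected VAE

print(grid_resnet.shape, n_varying_dims(grid_resnet))  # (121, 256) 2
print(grid_fc.shape, n_varying_dims(grid_fc))          # (121, 2) 2
```

Under these assumptions, the grid for the ResNet VAE covers a 2-D slice of a 256-D latent space, while for the fully connected VAE (latent_dim=2 in the code below) it covers the whole latent space, so the two sets of interpolation images are not directly comparable.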

This code gives me the following result (50 epochs, lr = 0.0001, Adam optimizer):

On my dataset (grayscale chromosome patches), I get similar results:

This post (How to fix a Variational Autoencoder (VAE) that suffers from mode collapse - Quora) suggests that loss balancing, i.e. the weighting between the reconstruction term and the KL term, could be the issue, but a hyperparameter search over kl_coeff didn’t turn anything up.
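To make the loss-balancing argument concrete, here is a minimal NumPy sketch of a kl_coeff-weighted VAE loss with a squared-error reconstruction term and the closed-form KL divergence to a standard normal prior. vae_loss and kl_diag_gaussian are hypothetical helpers; whether a given implementation sums or averages the reconstruction error over pixels is an assumption to check in its source. The example shows that averaging over the 784 MNIST pixels instead of summing is, up to an overall scale, the same as multiplying kl_coeff by 784.

```python
import numpy as np


def kl_diag_gaussian(mu: np.ndarray, logvar: np.ndarray) -> np.ndarray:
    """Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims."""
    return 0.5 * np.sum(mu ** 2 + np.exp(logvar) - 1.0 - logvar, axis=1)


def vae_loss(x, x_hat, mu, logvar, kl_coeff=1.0, recon_reduction="sum"):
    """kl_coeff-weighted negative ELBO with a squared-error reconstruction term."""
    sq_err = ((x_hat - x) ** 2).reshape(len(x), -1)
    if recon_reduction == "sum":
        recon = sq_err.sum(axis=1)   # sum over pixels
    else:
        recon = sq_err.mean(axis=1)  # average over pixels
    return float(np.mean(recon + kl_coeff * kl_diag_gaussian(mu, logvar)))


rng = np.random.default_rng(0)
x = rng.random((4, 1, 28, 28))
x_hat = np.full_like(x, x.mean())                 # a decoder that ignores z entirely
mu, logvar = np.zeros((4, 2)), np.zeros((4, 2))   # posterior equal to the prior: KL = 0

loss_sum = vae_loss(x, x_hat, mu, logvar, recon_reduction="sum")
loss_mean = vae_loss(x, x_hat, mu, logvar, recon_reduction="mean")
ratio = loss_sum / loss_mean
print(round(ratio, 6))  # 784.0
```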

However, a fully connected VAE gives much better results, both on MNIST and on my chromosome dataset:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class encoder(nn.Module):

    def __init__(self, img_height, h_dim1, h_dim2, **kwargs) -> None:
        super().__init__()
        self.fc1 = nn.Linear(img_height * img_height, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)

    def forward(self, x):
        # (B, 1, H, W) -> (B, H*W)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x


class decoder(nn.Module):

    def __init__(self, z_dim, h_dim1, h_dim2, img_height):
        super().__init__()
        self.img_height = img_height
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, img_height * img_height)

    def forward(self, x):
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        # torch.sigmoid replaces the deprecated F.sigmoid; outputs lie in [0, 1]
        x = torch.sigmoid(self.fc6(x))
        # (B, H*W) -> (B, 1, H, W)
        x = x.view(-1, 1, self.img_height, self.img_height)
        return x


class VAE(pl.LightningModule):
    """
    Fully connected VAE with Gaussian prior and approximate posterior.
    """

    def __init__(
        self,
        img_height: int = 28,
        h_dim1: int = 512,
        h_dim2: int = 256,
        latent_dim: int = 2,
        kl_coeff: float = 0.1,
        lr: float = 1e-4,
        **kwargs
    ):
        """
        Args:
            img_height: height (and width) of the square input images
            h_dim1: size of the first hidden layer
            h_dim2: size of the second hidden layer
            latent_dim: dim of latent space
            kl_coeff: coefficient for kl term of the loss
            lr: learning rate for Adam
        """
        super().__init__()
        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.latent_dim = latent_dim
        self.h_dim1 = h_dim1
        self.h_dim2 = h_dim2
        self.img_height = img_height

        # heads mapping encoder features to the posterior mean and (log-)variance
        self.fc_mu = nn.Linear(self.h_dim2, self.latent_dim)
        self.fc_var = nn.Linear(self.h_dim2, self.latent_dim)

        self.encoder = encoder(self.img_height, self.h_dim1, self.h_dim2)
        self.decoder = decoder(self.latent_dim, self.h_dim1, self.h_dim2, self.img_height)
```
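The snippet above only defines the layers. A minimal NumPy sketch of the sampling step that connects fc_mu and fc_var to the decoder, assuming that fc_var predicts a log-variance (which is how the pl_bolts VAE appears to use it) and that z is drawn with the reparameterization trick; reparameterize is a hypothetical helper.

```python
import numpy as np


def reparameterize(mu: np.ndarray, logvar: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Sample z ~ N(mu, exp(logvar)) as mu + std * eps, with eps ~ N(0, I)."""
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    return mu + std * eps


rng = np.random.default_rng(0)
batch, latent_dim = 1000, 2                          # latent_dim=2 as in the model above
mu = np.full((batch, latent_dim), 1.5)
logvar = np.full((batch, latent_dim), np.log(0.25))  # variance 0.25, i.e. std 0.5

z = reparameterize(mu, logvar, rng)
print(z.shape)  # (1000, 2)
# per dimension, the sample mean is near 1.5 and the sample std near 0.5
print(z.mean(axis=0), z.std(axis=0))
```

Because the noise enters only through eps, gradients can flow back into mu and logvar, which is why this form is used instead of sampling z directly.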

The rest of the training code is the same as in pl_bolts.models.autoencoders.VAE. This gives me the following results (on MNIST):
(on my dataset):

I also tried another, smaller convolutional architecture in several configurations (transposed convolutions, upsampling followed by a convolution, binary cross-entropy or L1 loss), but I either get the same results with no sample diversity, or the training collapses completely (all-white or all-black images). It really seems I’m missing something about convolutional architectures.
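The two upsampling options mentioned above can both be sized to reach 28x28. A minimal pure-Python sketch using the nn.ConvTranspose2d and nn.Conv2d output-size formulas (dilation=1), for a hypothetical decoder that starts from a 7x7 feature map; conv2d_out and conv_transpose2d_out are illustrative helpers, not library functions.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0,
                         output_padding: int = 0) -> int:
    """Output height/width of nn.ConvTranspose2d (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Option 1: transposed convolutions with kernel 4, stride 2, padding 1 double the size
sizes_transposed = [7]
for _ in range(2):
    sizes_transposed.append(
        conv_transpose2d_out(sizes_transposed[-1], kernel=4, stride=2, padding=1))

# Option 2: x2 nearest-neighbour upsampling, then a size-preserving 3x3 conv (padding 1)
sizes_upsample = [7]
for _ in range(2):
    sizes_upsample.append(conv2d_out(sizes_upsample[-1] * 2, kernel=3, padding=1))

print(sizes_transposed)  # [7, 14, 28]
print(sizes_upsample)    # [7, 14, 28]
```

kernel_size=4 with stride=2 is a common choice for transposed convolutions because the kernel size is divisible by the stride, which keeps the overlap between neighbouring output windows even.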

So if anybody is aware of any gross mistake in this code, or a well-known trick that needs to be followed to train convolutional VAEs, I’m all ears :slight_smile: