I hope this is OK to post here, as it’s not directly related to PyTorch. I’ll delete it if it doesn’t fit.
I am trying to train a variational autoencoder on small (64x64) grayscale patches. Ultimately, the downstream task is classification, but since most of my data is not labelled, semi-supervised learning seems like an interesting solution.
I am using PyTorch Lightning, but training a VAE on my images leads to absolutely no sample variety. I have the following code:
import os
from torch.nn import Conv2d
from pytorch_lightning import Trainer
from pl_bolts.models.autoencoders import VAE
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.callbacks.variational import LatentDimInterpolator as LDI
# interpolator callback
interpolator = LDI(interpolate_epoch_interval = 1)
# training
dataset = MNISTDataModule(num_workers = 8, batch_size = 128)
model = VAE(input_height = 28, enc_type = "resnet18", kl_coeff=1)
# changing first and last convolution to deal with 28x28 grayscale images
model.encoder.conv1 = Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.decoder.conv1 = Conv2d(64*model.decoder.expansion, 1, kernel_size=3, stride=1, padding=3, bias=False)
print(model)
# fit model
trainer = Trainer(gpus=1, max_epochs=50, callbacks = [interpolator], default_root_dir = os.getcwd())
trainer.fit(model, dataset)
trainer.save_checkpoint(os.getcwd())
I only modified the first convolution of the ResNet encoder and the last convolution of the decoder, because the images have only 1 channel. I also made some small modifications to the callback code: I added a save directory to see the images, as I’m not using TensorBoard, and I changed a self.module call into self.module.decoder(), as it was not running before.
This code gives me the following result (50 epochs, lr = 0.0001, Adam optimizer):
On my dataset (grayscale chromosome patches) I get similar results:
This post (https://www.quora.com/How-do-you-fix-a-Variational-Autoencoder-VAE-that-suffers-from-mode-collapse) suggests that loss balancing could be an issue, but a hyperparameter search over kl_coeff didn’t turn anything up.
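To make the loss-balancing point concrete, here is a minimal sketch of a KL-weighted VAE loss. It is not the pl_bolts implementation: the function name `vae_loss`, the choice of a summed MSE reconstruction term, and the analytic KL formula are assumptions made for illustration. It also returns the KL contribution of each latent dimension, since values close to zero everywhere would suggest the decoder is ignoring the latent code, which is one possible cause of identical samples.

```python
import torch
import torch.nn.functional as F


def vae_loss(x_hat, x, mu, log_var, kl_coeff=1.0):
    """Return (total, recon, kl, kl_per_latent_dim) for a Gaussian-posterior VAE.

    Hypothetical helper for illustration; mu and log_var have shape (N, latent_dim).
    """
    batch_size = x.shape[0]
    # reconstruction error summed over pixels, averaged over the batch
    recon = F.mse_loss(x_hat, x, reduction="sum") / batch_size
    # analytic KL(q(z|x) || N(0, I)) for each sample and latent dimension
    kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
    # sum over latent dimensions, average over the batch
    kl = kl_per_dim.sum(dim=1).mean()
    total = recon + kl_coeff * kl
    return total, recon, kl, kl_per_dim.mean(dim=0)


# usage with random stand-in tensors
x = torch.rand(8, 1, 28, 28)
x_hat = torch.rand(8, 1, 28, 28)
mu = torch.randn(8, 2)
log_var = torch.randn(8, 2)
total, recon, kl, kl_dims = vae_loss(x_hat, x, mu, log_var, kl_coeff=0.1)
print(f"recon={recon.item():.3f} kl={kl.item():.3f} per-dim KL={kl_dims.tolist()}")
```

Note that the two terms are on very different scales depending on whether the reconstruction error is summed or averaged over pixels, so a given kl_coeff is only meaningful relative to that choice.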
However, a fully connected VAE gives much better results, both on MNIST and on my chromosome dataset:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn


class encoder(nn.Module):
    def __init__(self, img_height, h_dim1, h_dim2, **kwargs) -> None:
        super().__init__()
        self.fc1 = nn.Linear(img_height * img_height, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)

    def forward(self, x):
        # flatten (N, 1, H, W) images into (N, H*W) vectors
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x


class decoder(nn.Module):
    def __init__(self, z_dim, h_dim1, h_dim2, img_height):
        super().__init__()
        self.img_height = img_height
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, img_height * img_height)

    def forward(self, x):
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        # squash outputs to (0, 1)
        x = torch.sigmoid(self.fc6(x))
        # reshape flat vectors back into (N, 1, H, W) images
        x = x.view(-1, 1, self.img_height, self.img_height)
        return x


class VAE(pl.LightningModule):
    """
    Fully connected VAE with Gaussian prior and approximate posterior.
    """

    def __init__(
        self,
        img_height: int = 28,
        h_dim1: int = 512,
        h_dim2: int = 256,
        latent_dim: int = 2,
        kl_coeff: float = 0.1,
        lr: float = 1e-4,
        **kwargs
    ):
        """
        Args:
            img_height: height (and width) of the square input images
            h_dim1: width of the first encoder layer and last hidden decoder layer
            h_dim2: width of the second encoder layer and first decoder layer
            latent_dim: dim of latent space
            kl_coeff: coefficient for kl term of the loss
            lr: learning rate for Adam
        """
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.kl_coeff = kl_coeff
        self.latent_dim = latent_dim
        self.h_dim1 = h_dim1
        self.h_dim2 = h_dim2
        self.img_height = img_height
        # heads mapping encoder features to the posterior mean and log-variance
        self.fc_mu = nn.Linear(self.h_dim2, self.latent_dim)
        self.fc_var = nn.Linear(self.h_dim2, self.latent_dim)
        self.encoder = encoder(self.img_height, self.h_dim1, self.h_dim2)
        self.decoder = decoder(self.latent_dim, self.h_dim1, self.h_dim2, self.img_height)
The rest of the training code is the same as in pl_bolts.models.autoencoders.VAE. This gives me the following results (on MNIST):
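For readers unfamiliar with that training code, the following is a minimal sketch of the sampling step such a VAE relies on, plus a crude numeric check for sample variety. It is not copied from pl_bolts: `reparameterize` and `sample_diversity` are hypothetical helper names, and the diversity measure (mean per-pixel standard deviation across decoded prior samples) is just one simple assumption about how to quantify "no sample variety".

```python
import torch


def reparameterize(mu, log_var):
    """Draw z ~ N(mu, sigma^2) in a way that keeps gradients w.r.t. mu and log_var."""
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std


@torch.no_grad()
def sample_diversity(decoder, latent_dim, n_samples=64, device="cpu"):
    """Decode n_samples prior draws and return the mean per-pixel std across them.

    A value near zero means every latent code decodes to almost the same image.
    """
    z = torch.randn(n_samples, latent_dim, device=device)
    samples = decoder(z)
    return samples.std(dim=0).mean().item()


# usage with a stand-in decoder that ignores z (hypothetical worst case)
constant_decoder = lambda z: torch.ones(z.shape[0], 1, 28, 28)
print(sample_diversity(constant_decoder, latent_dim=2))
```

Tracking this number per epoch for both the fully connected and the convolutional decoder would show whether the convolutional one ever responds to changes in z.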
(on my dataset):
I also tried another (smaller) convolutional architecture (with different configurations, like transposed convolutions, upsample then convolution, binary cross entropy or L1 loss) but I either get those results with no sample diversity, or the training collapses completely (completely white / black images). It really seems I’m missing something with the convolutional architectures.
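For clarity, the two upsampling variants mentioned above can be sketched as follows. This is only an illustration under assumed settings (kernel sizes, channel counts, ReLU activations, and the helper names are not taken from the actual experiments); both blocks double the spatial resolution.

```python
import torch
from torch import nn


def transposed_block(in_ch, out_ch):
    """Learned upsampling: kernel 4, stride 2, padding 1 maps H -> 2H."""
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
        nn.ReLU(),
    )


def upsample_conv_block(in_ch, out_ch):
    """Fixed nearest-neighbour upsampling followed by a size-preserving convolution."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(),
    )


# usage: both variants turn an 8-channel 16x16 feature map into 4 channels at 32x32
features = torch.randn(2, 8, 16, 16)
print(transposed_block(8, 4)(features).shape)
print(upsample_conv_block(8, 4)(features).shape)
```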
So if anybody is aware of any gross mistake in this code, or of a well-known trick that needs to be followed to train convolutional VAEs, I’m all ears.