My VAE PyTorch Implementation for MNIST

python

import torch

import torch.nn.functional as F

from torch import nn

import torch

from torch.utils.data import DataLoader

from torchvision import datasets, transforms

from torchvision.utils import make_grid, save_image

from tqdm import tqdm

class Encoder(nn.Module):
    """Map a flattened input image to the parameters of the posterior q(z|x).

    Produces the mean and log-variance of a diagonal Gaussian over the
    latent code.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        # One shared hidden layer, then two separate linear heads.
        hidden = torch.relu(self.fc1(x))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

class Decoder(nn.Module):
    """Map a latent code z back to a flattened image with pixels in (0, 1)."""

    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        hidden = torch.relu(self.fc1(z))
        # Sigmoid keeps outputs in (0, 1), matching the BCE reconstruction loss.
        return torch.sigmoid(self.fc2(hidden))

class VAE(nn.Module):
    """Variational autoencoder: Gaussian latent, Bernoulli (BCE) likelihood."""

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
        self.flat = nn.Flatten(start_dim=1)

    def reparametrization(self, mu, logvar):
        """Draw z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + std * noise

    def loss(self, recons, x_true, mu, logvar):
        """Negative ELBO: summed BCE reconstruction + KL(q(z|x) || N(0, I))."""
        target = self.flat(x_true)
        bce = F.binary_cross_entropy(recons, target, reduction="sum")
        # Closed-form KL divergence for a diagonal Gaussian vs. standard normal.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return bce + kl

    def forward(self, x):
        flat_x = self.flat(x)
        mu, logvar = self.encoder(flat_x)
        z = self.reparametrization(mu, logvar)
        return self.decoder(z), mu, logvar

def fit(model, dataloader, optimizer, device, epochs=200, sample_every=20):
    """Train `model` for `epochs` passes over `dataloader`.

    Backward-compatible generalization: the epoch count (was hard-coded to
    200) and the sampling interval (was hard-coded to 20) are now keyword
    parameters with the original values as defaults.

    Args:
        model: VAE-style module exposing `loss(...)` and `decoder`.
        dataloader: iterable of (image, label) batches.
        optimizer: torch optimizer over `model.parameters()`.
        device: device the batches are moved to ("cpu"/"cuda").
        epochs: number of training epochs (default 200).
        sample_every: save a grid of generated samples every this many epochs.
    """
    for epoch in tqdm(range(epochs)):
        train_one_epoch(model, dataloader, optimizer, device)
        # Periodic qualitative check: write decoded prior samples to disk.
        if epoch % sample_every == 0:
            sample_and_save(model, device)

def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over `dataloader` (labels are ignored)."""
    model.train()
    for batch, _ in dataloader:
        batch = batch.to(device)
        recons, mu, logvar = model(batch)
        batch_loss = model.loss(recons, batch, mu, logvar)

        # Standard backpropagation step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

def sample_and_save(model, device, num_samples=9, latent_dim=20,
                    out_path="vae_samples.png"):
    """Decode random latent codes and save the generated images as a grid.

    Backward-compatible generalization: `latent_dim` (was hard-coded to 20)
    and `out_path` (was hard-coded to "vae_samples.png") are now parameters
    with the original values as defaults.

    Args:
        model: module exposing `decoder` that maps (N, latent_dim) -> (N, 784).
        device: device to sample on.
        num_samples: number of images to generate (default 9 -> 3x3 grid).
        latent_dim: dimensionality of the latent prior.
        out_path: file the image grid is written to.

    NOTE(review): the 28x28 reshape assumes MNIST-sized outputs — confirm
    before reusing with other datasets.
    """
    model.eval()
    with torch.no_grad():
        # Sample latent codes from the standard normal prior p(z).
        z = torch.randn(num_samples, latent_dim, device=device)
        samples = model.decoder(z)
        # Reshape flat 784-vectors back to (batch, 1, 28, 28) images.
        samples = samples.view(num_samples, 1, 28, 28)
        # 3 images per row -> a 3x3 grid for the default num_samples=9.
        grid = make_grid(samples, nrow=3, normalize=True)
        save_image(grid, out_path)
        print(f"Saved generated samples to {out_path}")

if __name__ == "__main__":
    # BUG FIX: the guard was `if _name_ == ā€œ_main_ā€:` — single underscores
    # and curly quotes, which is a syntax error / NameError in Python.
    # ToTensor yields pixels in [0, 1], matching the sigmoid/BCE output.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )

    # Load dataset (downloads MNIST on first run).
    dataset = datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )

    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    fit(model, dataloader, optimizer, device)

I’m unsure what your question is, so could you please add more information to your post?