"""Variational autoencoder (VAE) trained on MNIST, with periodic sample-image dumps."""
import torch
import torch.nn.functional as F
from torch import nn
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
class Encoder(nn.Module):
    """Encoder network of the VAE.

    Maps a flattened input vector to the parameters (mu, logvar) of a
    diagonal Gaussian approximate posterior q(z | x).
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Return (mu, logvar), each of shape (batch, latent_dim).

        `x` is expected to be already flattened to (batch, input_dim).
        """
        h = torch.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
class Decoder(nn.Module):
    """Decoder network of the VAE.

    Maps a latent vector z to a reconstruction in [0, 1]^output_dim
    (sigmoid output, suitable for the binary-cross-entropy loss used by VAE).
    """

    def __init__(self, latent_dim=20, hidden_dim=400, output_dim=784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        """Return reconstructions of shape (batch, output_dim), values in [0, 1]."""
        h = torch.relu(self.fc1(z))
        x_recon = torch.sigmoid(self.fc2(h))
        return x_recon
class VAE(nn.Module):
def \__init_\_(self, input_dim=784, hidden_dim=400, latent_dim=20):
super().\__init_\_()
self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
self.decoder = Decoder(latent_dim, hidden_dim, input_dim)
self.flat = nn.Flatten(start_dim=1)
def reparametrization(self, mu, logvar):
sigma = torch.exp(0.5 \* logvar)
eps = torch.randn_like(sigma)
return mu + sigma \* eps
def loss(self, recons, x_true, mu, logvar):
x_gt = self.flat(x_true)
recon_term = F.binary_cross_entropy(recons, x_gt, reduction="sum")
kl_term = -0.5 \* torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_term + kl_term
def forward(self, x):
x = self.flat(x)
mu, logvar = self.encoder(x)
z = self.reparametrization(mu, logvar)
recons = self.decoder(z)
return recons, mu, logvar
def fit(model, dataloader, optimizer, device, epochs=200, sample_every=20):
    """Run the full training loop.

    Args:
        model: VAE-style module (forward -> (recons, mu, logvar), plus .loss).
        dataloader: yields (image, label) batches; labels are ignored.
        optimizer: torch optimizer over model.parameters().
        device: "cuda" or "cpu".
        epochs: number of epochs to train (was hard-coded to 200).
        sample_every: save generated samples every this many epochs
            (was hard-coded to 20).
    """
    for epoch in tqdm(range(epochs)):
        train_one_epoch(model, dataloader, optimizer, device)
        # Epoch 0 included: gives an immediate baseline image of the untrained model.
        if epoch % sample_every == 0:
            sample_and_save(model, device)
def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over `dataloader`.

    Returns the mean per-batch loss (float) for monitoring, or None if the
    dataloader is empty. (Previous version returned nothing; callers that
    ignore the return value are unaffected.)
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for image, _ in dataloader:  # labels are unused for a VAE
        image = image.to(device)
        recons, mu, logvar = model(image)
        loss = model.loss(recons, image, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else None
def sample_and_save(model, device, num_samples=9, latent_dim=20):
    """Sample `num_samples` latents from N(0, I), decode, and save a 3x3 grid.

    Args:
        model: module exposing a `decoder` mapping (n, latent_dim) -> (n, 784).
        device: device to sample on.
        num_samples: number of images to generate (grid uses nrow=3).
        latent_dim: dimensionality of the latent space; must match the model's
            decoder (was hard-coded to 20).
    """
    model.eval()
    with torch.no_grad():
        # Sample from the standard normal prior.
        z = torch.randn(num_samples, latent_dim, device=device)
        # Decode and reshape to image format (batch, 1, 28, 28) — MNIST-sized.
        samples = model.decoder(z).view(num_samples, 1, 28, 28)
        grid = make_grid(samples, nrow=3, normalize=True)
        save_image(grid, "vae_samples.png")
    print("Saved generated samples to vae_samples.png")
if __name__ == "__main__":
    # ToTensor scales MNIST pixels to [0, 1], as required by the BCE loss.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    # Load dataset (downloads on first run).
    dataset = datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    # Create dataloader.
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VAE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    fit(model, dataloader, optimizer, device)