import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
class Discriminator(nn.Module):
    def __init__(self, img_channels=3, features=64):  # input: 3x64x64
        super().__init__()
        self.disc = nn.Sequential(
            self._make_conv2d_block(img_channels, features, use_bn=False),  # 64x32x32
            self._make_conv2d_block(features, features*2),                  # 128x16x16
            self._make_conv2d_block(features*2, features*4),                # 256x8x8
            self._make_conv2d_block(features*4, features*8),                # 512x4x4
            self._make_conv2d_block(features*8, 1, 4, 2, 0, use_act=False), # 1x1x1
            nn.Sigmoid(),
        )

    def _make_conv2d_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True, leak=0.2):
        # Conv -> (BatchNorm) -> (LeakyReLU); bias is redundant when BatchNorm follows.
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
        ]
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if use_act:
            layers.append(nn.LeakyReLU(leak))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.disc(x)
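# Illustrative check, not part of the original listing: each 4x4 conv with
# stride 2 and padding 1 halves the spatial size, since
# out = floor((in + 2*1 - 4) / 2) + 1 = in / 2 for even inputs,
# so the stack maps 64 -> 32 -> 16 -> 8 -> 4, and the final 4x4/stride-2/pad-0
# conv maps 4 -> 1. The probe name below is a throwaway for this check.
_conv_probe = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
assert _conv_probe(torch.randn(1, 3, 64, 64)).shape == (1, 64, 32, 32)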
class Generator(nn.Module):
    def __init__(self, z_dim=100, features=64, img_channels=3):  # input: 100x1x1
        super().__init__()
        self.gen = nn.Sequential(
            self._make_convT2d_block(z_dim, features*16, 4, 1, 0),  # 1024x4x4
            self._make_convT2d_block(features*16, features*8),      # 512x8x8
            self._make_convT2d_block(features*8, features*4),       # 256x16x16
            self._make_convT2d_block(features*4, features*2),       # 128x32x32
            self._make_convT2d_block(features*2, img_channels, use_bn=False, use_act=False),  # 3x64x64
            nn.Tanh(),
        )

    def _make_convT2d_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True):
        # ConvTranspose -> (BatchNorm) -> (ReLU); bias is redundant when BatchNorm follows.
        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
        ]
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
        if use_act:
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.gen(x)
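# Illustrative check, not part of the original listing: a 4x4 transposed conv
# with stride 2 and padding 1 doubles the spatial size, since
# out = (in - 1)*2 - 2*1 + 4 = 2*in, while the first block (stride 1,
# padding 0) maps the 1x1 noise vector to 4x4. The probe name is a throwaway.
_convT_probe = nn.ConvTranspose2d(100, 1024, kernel_size=4, stride=1, padding=0)
assert _convT_probe(torch.randn(1, 100, 1, 1)).shape == (1, 1024, 4, 4)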
def initialize_weights(model):
    # DCGAN initialization: all conv and batchnorm weights ~ N(0, 0.02).
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
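# Quick check, an illustrative addition: the DCGAN paper draws all weights
# from a zero-centered Normal with std 0.02, so a freshly initialized layer
# should report a weight std near that value. The probe is a throwaway.
_init_probe = nn.Conv2d(3, 64, kernel_size=4)
initialize_weights(nn.Sequential(_init_probe))
print(f'weight std after init: {_init_probe.weight.std().item():.4f}')  # ~0.02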
BATCH_SIZE, NUM_CHANNELS, H, W = 8, 3, 64, 64
Z_DIM = 100
FEATURES = 64
img_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, H, W))
disc = Discriminator(NUM_CHANNELS, FEATURES)
assert disc(img_test).shape == (BATCH_SIZE, 1, 1, 1), 'Disc Test Failed'
noise_test = torch.randn((BATCH_SIZE, Z_DIM, 1, 1))
gen = Generator(Z_DIM, FEATURES, NUM_CHANNELS)
assert gen(noise_test).shape == (BATCH_SIZE, NUM_CHANNELS, H, W), 'Gen Test Failed'
print('Success!')  # Success
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_CHANNELS = 1
BATCH_SIZE = 64
IMG_SIZE = 64
Z_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 64
LR = 2e-4
NUM_EPOCHS = 5
DISC_ITER = 1
GRID_SHOW = 8
disc = Discriminator(NUM_CHANNELS, FEATURES_DISC).to(device)
gen = Generator(Z_DIM, FEATURES_GEN, NUM_CHANNELS).to(device)
initialize_weights(disc)
initialize_weights(gen)
fixed_noise = torch.randn((GRID_SHOW, Z_DIM, 1, 1)).to(device)
transformations = transforms.Compose(
[
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
transforms.Normalize(
[0.5 for _ in range(NUM_CHANNELS)], [0.5 for _ in range(NUM_CHANNELS)]
),
]
)
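# Note (explanatory addition): Normalize with mean 0.5 and std 0.5 maps pixels
# from [0, 1] to [-1, 1], matching the Tanh output range of the generator.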
images = datasets.MNIST(root='mnist', train=True, transform=transformations, download=True)
images_loader = DataLoader(images, batch_size=BATCH_SIZE, shuffle=True)
optim_disc = optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.999))
optim_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.999))
criterion = nn.BCELoss()
writer_fake = SummaryWriter('logs/fake')
writer_real = SummaryWriter('logs/real')
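# The image grids logged below can be watched live with the standard
# TensorBoard invocation (assuming TensorBoard is installed):
#   tensorboard --logdir logs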
steps = 1
disc.train()
gen.train()
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(images_loader):
        real = real.to(device)

        # Train the discriminator: maximize log(D(real)) + log(1 - D(G(z))),
        # i.e. BCE against targets of 1 for real and 0 for fake images.
        for _ in range(DISC_ITER):
            noise = torch.randn((real.shape[0], Z_DIM, 1, 1)).to(device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            loss_disc_real = criterion(
                input=disc_real,
                target=torch.ones_like(disc_real),
            )
            disc_fake = disc(fake.detach()).view(-1)
            loss_disc_fake = criterion(
                input=disc_fake,
                target=torch.zeros_like(disc_fake),
            )
            loss_disc = loss_disc_real + loss_disc_fake
            disc.zero_grad()
            loss_disc.backward()
            optim_disc.step()

        # Train the generator: maximize log(D(G(z))) by scoring the fake batch
        # against targets of 1 (the non-saturating GAN loss).
        disc_fake = disc(fake).view(-1)
        loss_gen = criterion(
            input=disc_fake,
            target=torch.ones_like(disc_fake),
        )
        gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        # Log losses and image grids every 50 batches.
        if batch_idx % 50 == 0:
            print(
                f'Epoch: {epoch+1}/{NUM_EPOCHS} -- Step: {steps} -- Batch: {batch_idx+1}/{len(images_loader)} -- Disc Loss: {loss_disc:.4f} -- Gen Loss: {loss_gen:.4f}'
            )
            with torch.no_grad():
                fake = gen(fixed_noise)
                fake_images = torchvision.utils.make_grid(fake, nrow=GRID_SHOW, normalize=True)
                real_images = torchvision.utils.make_grid(real[:GRID_SHOW], nrow=GRID_SHOW, normalize=True)
                writer_fake.add_image('Fake', fake_images, global_step=steps)
                writer_real.add_image('Real', real_images, global_step=steps)
            steps += 1
Epoch: 1/5 -- Step: 1 -- Batch: 1/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6932
Epoch: 1/5 -- Step: 2 -- Batch: 51/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6932
Epoch: 1/5 -- Step: 3 -- Batch: 101/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6932
Epoch: 1/5 -- Step: 4 -- Batch: 151/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6931
Epoch: 1/5 -- Step: 5 -- Batch: 201/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6931
Epoch: 1/5 -- Step: 6 -- Batch: 251/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6932
Epoch: 1/5 -- Step: 7 -- Batch: 301/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6931
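A quick sanity check on these opening numbers: an untrained discriminator outputs roughly 0.5 everywhere, and BCE gives -ln(0.5) ≈ 0.6931 per term, so the generator loss starts at ~0.6931 and the discriminator loss (real term plus fake term) at 2 × 0.6931 ≈ 1.3863, exactly as logged.

After training, a small piece of housekeeping is worth adding; the sketch below is an assumption, not part of the original run, and the checkpoint filename is hypothetical.

# Post-training cleanup (illustrative sketch): flush the TensorBoard writers
# and save the trained generator for later sampling.
writer_fake.close()
writer_real.close()
torch.save(gen.state_dict(), 'dcgan_generator.pth')  # hypothetical filename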