I have trained my implementation of Pix2Pix on the face2comics dataset and although the generated images are sharp and realistic, they are too bright. A tanh
activation in the last layer of the generator outputs the generated images in the range [-1, 1]
. The training images are normalized to have zero mean
and std
of one by computing the statistics of the training set.
However, when I check if the training and output images are in those ranges, only the generated images are in the desired range and individual training images are well above and/or below the [-1, 1]
range. It is only when I compute the statistics of the entire training set after normalization that I get zero mean and unit variance.
Is this behaviour expected? How can I fix the “saturation” of the generated images?
Dataset module:
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
def mean_std(dataset):
"""Return the mean and std of the dataset."""
loader = DataLoader(dataset, batch_size=128, num_workers=0, shuffle=False)
mean_inputs = 0.
std_inputs = 0.
mean_targets = 0.
std_targets = 0.
for inputs, targets in tqdm(loader):
inputs = inputs.to(DEVICE).view(inputs.size(0), inputs.size(1), -1)
mean_inputs += inputs.mean(2).sum(0)
std_inputs += inputs.std(2).sum(0)
targets = targets.to(DEVICE).view(targets.size(0), inputs.size(1), -1)
mean_targets += targets.mean(2).sum(0)
std_targets += targets.std(2).sum(0)
mean_inputs /= len(loader.dataset)
std_inputs /= len(loader.dataset)
mean_targets /= len(loader.dataset)
std_targets /= len(loader.dataset)
return (mean_inputs, std_inputs), (mean_targets, std_targets)
class Face2Comic(Dataset):
"""A paired face-to-comics dataset."""
def __init__(self, data_dir, train=True):
super(Face2Comic, self).__init__()
self.data_dir = data_dir
self.faces_dir = os.path.join(data_dir, "faces")
self.faces = os.listdir(self.faces_dir)
self.comics_dir = os.path.join(data_dir, "comics")
self.comics = os.listdir(self.comics_dir)
self.len = len(self.faces)
self.train = train
def apply_transforms(self, face, comic):
"""Apply the same transforms to the input and the target."""
common_transform = transforms.Compose([transforms.Resize((256, 256)),
transforms.ToTensor()])
normalize_face = transforms.Normalize(mean=[0.5129, 0.4136, 0.3671],
std=[0.2372, 0.1972, 0.1883])
normalize_comic = transforms.Normalize(mean=[0.4445, 0.3650, 0.3226],
std=[0.2594, 0.2051, 0.1840])
face = normalize_face(common_transform(face))
comic = normalize_comic(common_transform(comic))
if self.train:
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter()])
face = train_transform(face)
comic = train_transform(comic)
return face, comic
def __len__(self):
"""Get the number of samples in the dataset."""
return self.len
def __getitem__(self, index):
"""Return the transformed input (face) and target (comic)."""
face = Image.open(os.path.join(self.faces_dir, self.faces[index]))
comic = Image.open(os.path.join(self.comics_dir, self.comics[index]))
return self.apply_transforms(face, comic)
if __name__ == '__main__':
data_dir_train = os.getcwd() + '\\data\\train\\'
dataset_train = Face2Comic(data_dir=data_dir_train, train=True)
stats_faces, stats_comics = mean_std(dataset_train)
print(f"Faces: mean = {stats_faces[0]}, std = {stats_faces[1]}")
print(f"Comics: mean = {stats_comics[0]}, std = {stats_comics[1]}")
data_dir_val = os.getcwd() + '\\data\\val\\'
dataset_val = Face2Comic(data_dir=data_dir_val, train=False)
stats_faces, stats_comics = mean_std(dataset_val)
print(f"Faces: mean = {stats_faces[0]}, std = {stats_faces[1]}")
print(f"Comics: mean = {stats_comics[0]}, std = {stats_comics[1]}")
Model module:
import torch
import torch.nn as nn
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.cuda.empty_cache()
def make_conv(in_size, out_size, encode, batch_norm, activation, drop_out):
"""Convolutional blocks of the Generator and the Discriminator.
Let Ck denote a Convolution-BtachNorm-ReLU block with k filters.
CDk denotes a Convolution-BtachNorm-Dropout-ReLU block with 50% dropout.
All convolutions are 4 x 4 spatial filters with stride 2. Convolutions in
the encoder and discriminator downsample by a factor of 2, whereas in the
decoder they upsample by a factor of 2.
"""
block = [nn.Conv2d(in_size, out_size,
kernel_size=4, stride=2, padding=1,
padding_mode="reflect",
bias=False if batch_norm else True)
if encode else
nn.ConvTranspose2d(in_size, out_size,
kernel_size=4, stride=2, padding=1,
bias=False if batch_norm else True)]
if batch_norm:
block.append(nn.BatchNorm2d(out_size))
if activation == "leaky":
block.append(nn.LeakyReLU(0.2))
elif activation == "sigmoid":
block.append(nn.Sigmoid())
elif activation == "tanh":
block.append(nn.Tanh())
elif activation == "relu":
block.append(nn.ReLU())
if drop_out:
block.append(nn.Dropout(0.5))
return nn.Sequential(*block)
def init_weights(model, mean=0.0, std=0.02):
"""Initialize weights from a Gaussian distribution."""
for module in model.modules():
if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
nn.init.normal_(module.weight.data, mean=mean, std=std)
class Generator(nn.Module):
"""UNet Generator architecture.
encoder:
C64-C128-C256-C512-C512-C512-C512-C512
decoder:
CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
After the C128 block in the decoder, a convolution is applied to map to the
number of output channels, followed by a Tanh function. BatchNorm is not
applied to the C64 block in the encoder. All ReLUs in the econder are
leaky with slope 0.2, while ReLUs in the decoder are not leaky.
"""
def __init__(self, in_channels=3, out_channels=3):
super(Generator, self).__init__()
encoder = [in_channels, 64, 128, 256, 512, 512, 512, 512, 512]
encoder = zip(encoder, encoder[1:])
self.encoder = nn.ModuleList()
for idx, (input_size, output_size) in enumerate(encoder):
if idx == 0:
input_size *= 2
batch_norm = False
else:
batch_norm = True
self.encoder.append(make_conv(in_size=input_size,
out_size=output_size,
encode=True,
batch_norm=batch_norm,
activation="leaky",
drop_out=False))
decoder = [512, 1024, 1024, 1024, 1024, 512, 256, 128, out_channels]
layers_decoder = len(decoder)
decoder = zip(decoder, decoder[1:])
self.decoder = nn.ModuleList()
for idx, (input_size, output_size) in enumerate(decoder):
if idx < layers_decoder - 2:
batch_norm = True
activation = "relu"
output_size //= 2
else:
batch_norm = False
activation = "tanh"
self.decoder.append(make_conv(in_size=input_size,
out_size=output_size,
encode=False,
batch_norm=batch_norm,
activation=activation,
drop_out=True if idx < 3 else False))
init_weights(self, mean=0.0, std=0.02)
def forward(self, x, z):
"""Generate a translation of x conditioned on the noise z."""
x = torch.cat((x, z), dim=1)
skip = [None]*len(self.encoder)
for idx, block in zip(range(len(skip)-1, -1, -1), self.encoder):
x = block(x)
skip[idx] = x
for idx, block in enumerate(self.decoder):
if idx > 0:
x = torch.cat((x, skip[idx]), dim=1)
x = block(x)
return x
class Discriminator(nn.Module):
"""C64-C128-C256-C512 PatchGAN Discriminator architecture.
After the C512 block, a convolution is applied to map to a 1-d output,
followed by a Sigmoid function. BatchNorm is not applied to the c64 block.
All ReLUs are leaky with slope of 0.2.
"""
def __init__(self, in_channels=3):
super(Discriminator, self).__init__()
channels = [in_channels, 64, 128, 256, 512, 1]
layers = len(channels)
channels = zip(channels, channels[1:])
self.blocks = nn.ModuleList()
for layer, (input_size, output_size) in enumerate(channels):
if layer == 0:
input_size *= 2
batch_norm = False
activation = "leaky"
elif layer < layers - 2:
batch_norm = True
activation = "leaky"
else:
batch_norm = False
activation = "sigmoid"
self.blocks.append(make_conv(in_size=input_size,
out_size=output_size,
encode=True,
batch_norm=batch_norm,
activation=activation,
drop_out=False))
init_weights(self, mean=0.0, std=0.02)
def forward(self, x, y):
"""Return a nxn tensor of patch probabilities."""
x = torch.cat((x, y), dim=1)
for block in self.blocks:
x = block(x)
return x
if __name__ == '__main__':
batch_size = 8
channels = 3
height = 256
width = 256
x = torch.randn((batch_size, channels, height, width), device=DEVICE)
y = torch.randn((batch_size, channels, height, width), device=DEVICE)
z = torch.randn((batch_size, channels, height, width), device=DEVICE)
generator = Generator().to(DEVICE)
total_params = sum(p.numel() for p in generator.parameters())
print(f"Number of parameters in Generator: {total_params:,}")
G_z = generator(x, z)
print(G_z.shape)
discriminator = Discriminator().to(DEVICE)
total_params = sum(p.numel() for p in discriminator.parameters())
print(f"Number of parameters in Discriminator: {total_params:,}")
D_x = discriminator(x, y)
print(D_x.shape)