GANs Pix2Pix Generator image brightness and standardizing training set

I have trained my implementation of Pix2Pix on the face2comics dataset. Although the generated images are sharp and realistic, they are too bright. A tanh activation in the last layer of the generator constrains the generated images to the range [-1, 1]. The training images are standardized per channel to zero mean and unit standard deviation, using statistics computed over the training set.


However, when I check if the training and output images are in those ranges, only the generated images are in the desired range and individual training images are well above and/or below the [-1, 1] range. It is only when I compute the statistics of the entire training set after normalization that I get zero mean and unit variance.

Is this behaviour expected? How can I fix the “saturation” of the generated images?

Dataset module:

import os
import torch

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
# Ask PyTorch to release any cached, currently unused GPU memory.
torch.cuda.empty_cache()


def mean_std(dataset):
    """Return per-channel mean and std estimates of the inputs and targets.

    Every image is flattened to ``(channels, pixels)``, its per-channel mean
    and std are computed, and these per-image statistics are averaged over
    the whole dataset. The reported std is therefore the *average of the
    per-image stds*, not the std of all pixels pooled together.

    Args:
        dataset: a dataset yielding ``(input, target)`` tensor pairs; batches
            are assumed to be shaped ``(batch, channels, ...)``.

    Returns:
        ``((mean_inputs, std_inputs), (mean_targets, std_targets))`` where
        each entry holds one value per channel.
    """
    loader = DataLoader(dataset, batch_size=128, num_workers=0, shuffle=False)

    mean_inputs = 0.
    std_inputs = 0.
    mean_targets = 0.
    std_targets = 0.

    for inputs, targets in tqdm(loader):

        # Flatten the spatial dimensions: (batch, channels, pixels).
        inputs = inputs.to(DEVICE).view(inputs.size(0), inputs.size(1), -1)
        # Sum the per-image statistics; they are divided by N after the loop.
        mean_inputs += inputs.mean(2).sum(0)
        std_inputs += inputs.std(2).sum(0)

        # Use the targets' own channel count: inputs and targets need not
        # have the same number of channels.
        targets = targets.to(DEVICE).view(targets.size(0), targets.size(1), -1)
        mean_targets += targets.mean(2).sum(0)
        std_targets += targets.std(2).sum(0)

    num_samples = len(loader.dataset)
    mean_inputs /= num_samples
    std_inputs /= num_samples
    mean_targets /= num_samples
    std_targets /= num_samples

    return (mean_inputs, std_inputs), (mean_targets, std_targets)


class Face2Comic(Dataset):
    """A paired face-to-comics dataset.

    ``data_dir`` must contain a ``faces`` and a ``comics`` sub-directory.
    The i-th face is paired with the i-th comic after both directory listings
    are sorted by file name (presumably paired files share a name -- confirm
    against the dataset layout).

    Args:
        data_dir: directory holding the ``faces`` and ``comics`` folders.
        train: when true, random augmentation is applied to each pair.
    """

    def __init__(self, data_dir, train=True):
        super().__init__()
        self.data_dir = data_dir
        self.faces_dir = os.path.join(data_dir, "faces")
        # os.listdir returns entries in arbitrary order, so both listings are
        # sorted to guarantee that faces[i] and comics[i] belong together.
        self.faces = sorted(os.listdir(self.faces_dir))
        self.comics_dir = os.path.join(data_dir, "comics")
        self.comics = sorted(os.listdir(self.comics_dir))
        self.len = len(self.faces)
        self.train = train

    def apply_transforms(self, face, comic):
        """Apply the same transforms to the input and the target.

        Both images are resized to 256 x 256, converted to tensors and
        standardized with their own per-channel statistics. In training mode
        the pair is additionally flipped horizontally with probability 0.5;
        the flip decision is shared so the pair stays pixel-aligned.

        Returns:
            The transformed ``(face, comic)`` tensors.
        """
        common_transform = transforms.Compose([transforms.Resize((256, 256)),
                                               transforms.ToTensor()])

        # NOTE: standardizing with dataset statistics does not confine the
        # values to [-1, 1]; e.g. a white comic pixel in the first channel
        # maps to (1 - 0.4445) / 0.2594 ~= 2.14. If the model producing the
        # comics ends in a tanh, it cannot reach such targets; in that case
        # normalize the targets with mean = std = 0.5 instead.
        normalize_face = transforms.Normalize(mean=[0.5129, 0.4136, 0.3671],
                                              std=[0.2372, 0.1972, 0.1883])

        normalize_comic = transforms.Normalize(mean=[0.4445, 0.3650, 0.3226],
                                               std=[0.2594, 0.2051, 0.1840])

        face = normalize_face(common_transform(face))
        comic = normalize_comic(common_transform(comic))

        if self.train:
            # Draw the flip decision once and apply it to both images.
            # Sampling it separately per image would mirror only one half of
            # some pairs and destroy the input/target correspondence.
            if torch.rand(1).item() < 0.5:
                face = face.flip(-1)
                comic = comic.flip(-1)

            # With its default arguments (all jitter factors 0) ColorJitter
            # leaves the image unchanged; kept as a hook for augmentation.
            color_jitter = transforms.ColorJitter()
            face = color_jitter(face)
            comic = color_jitter(comic)

        return face, comic

    def __len__(self):
        """Get the number of samples in the dataset."""
        return self.len

    def __getitem__(self, index):
        """Return the transformed input (face) and target (comic)."""
        # Force three channels so the 3-channel normalization statistics
        # also apply to grayscale, palette or RGBA files.
        face = Image.open(
            os.path.join(self.faces_dir, self.faces[index])).convert("RGB")
        comic = Image.open(
            os.path.join(self.comics_dir, self.comics[index])).convert("RGB")
        return self.apply_transforms(face, comic)


if __name__ == '__main__':

    # Sanity check: print the per-channel statistics of the (already
    # normalized) training and validation splits.

    # Build the paths with os.path.join so the script is not tied to one
    # operating system's path separator.
    data_dir_train = os.path.join(os.getcwd(), 'data', 'train')
    dataset_train = Face2Comic(data_dir=data_dir_train, train=True)

    stats_faces, stats_comics = mean_std(dataset_train)

    print(f"Faces: mean = {stats_faces[0]}, std = {stats_faces[1]}")
    print(f"Comics: mean = {stats_comics[0]}, std = {stats_comics[1]}")

    data_dir_val = os.path.join(os.getcwd(), 'data', 'val')
    dataset_val = Face2Comic(data_dir=data_dir_val, train=False)

    stats_faces, stats_comics = mean_std(dataset_val)

    print(f"Faces: mean = {stats_faces[0]}, std = {stats_faces[1]}")
    print(f"Comics: mean = {stats_comics[0]}, std = {stats_comics[1]}")

Model module:

import torch
import torch.nn as nn


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
# Ask PyTorch to release any cached, currently unused GPU memory.
torch.cuda.empty_cache()


def make_conv(in_size, out_size, encode, batch_norm, activation, drop_out):
    """Build one convolutional block of the Generator or the Discriminator.

    In the Pix2Pix notation, Ck is a Convolution-BatchNorm-ReLU block with k
    filters and CDk is the same block with 50% dropout. Every convolution
    uses a 4 x 4 kernel with stride 2 and padding 1: encoder/discriminator
    blocks downsample by a factor of 2, decoder blocks upsample by 2.

    The layers are assembled in this order: convolution, optional BatchNorm,
    optional activation, optional dropout.

    Args:
        in_size: number of input channels.
        out_size: number of output channels (filters).
        encode: if true use a reflect-padded ``Conv2d``, otherwise a
            ``ConvTranspose2d``.
        batch_norm: if true add ``BatchNorm2d`` and drop the conv bias.
        activation: one of "leaky", "sigmoid", "tanh" or "relu"; any other
            value adds no activation.
        drop_out: if true append ``Dropout(0.5)``.

    Returns:
        The block as an ``nn.Sequential``.
    """
    # The convolution bias is only kept when no BatchNorm follows.
    use_bias = not batch_norm

    if encode:
        conv = nn.Conv2d(in_size, out_size,
                         kernel_size=4, stride=2, padding=1,
                         padding_mode="reflect",
                         bias=use_bias)
    else:
        conv = nn.ConvTranspose2d(in_size, out_size,
                                  kernel_size=4, stride=2, padding=1,
                                  bias=use_bias)

    layers = [conv]

    if batch_norm:
        layers.append(nn.BatchNorm2d(out_size))

    # Activation name -> factory; the module is only built when selected.
    activation_factories = (
        ("leaky", lambda: nn.LeakyReLU(0.2)),
        ("sigmoid", nn.Sigmoid),
        ("tanh", nn.Tanh),
        ("relu", nn.ReLU),
    )
    for name, factory in activation_factories:
        if activation == name:
            layers.append(factory())
            break

    if drop_out:
        layers.append(nn.Dropout(0.5))

    return nn.Sequential(*layers)


def init_weights(model, mean=0.0, std=0.02):
    """Initialize the weights of a model from Gaussian distributions.

    * ``Conv2d`` and ``ConvTranspose2d`` weights are drawn from
      ``N(mean, std)``. ``ConvTranspose2d`` is not a subclass of ``Conv2d``
      and has to be listed explicitly, otherwise the decoder convolutions
      keep their default initialization.
    * ``BatchNorm2d`` scale factors are drawn from ``N(1, std)`` and their
      biases set to 0, so each normalization layer starts close to the
      identity. Centring the scale on 0 would multiply every normalized
      activation by a tiny factor of random sign.

    Args:
        model: the ``nn.Module`` whose sub-modules are initialized in place.
        mean: mean of the Gaussian used for the convolution weights.
        std: standard deviation of every Gaussian used.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(module.weight.data, mean=mean, std=std)
        elif isinstance(module, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has no learnable parameters.
            if module.weight is not None:
                nn.init.normal_(module.weight.data, mean=1.0, std=std)
            if module.bias is not None:
                nn.init.constant_(module.bias.data, 0.0)


class Generator(nn.Module):
    """U-Net generator with skip connections between mirrored blocks.

    encoder:
        C64-C128-C256-C512-C512-C512-C512-C512
    decoder:
        CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128

    The decoder widths above count the channels *after* concatenating the
    skip connection. A final transposed convolution followed by Tanh maps to
    the number of output channels. BatchNorm is not applied to the C64
    encoder block. Encoder activations are leaky ReLUs with slope 0.2,
    decoder activations are plain ReLUs.

    NOTE(review): every encoder block, including the last one, carries a
    BatchNorm; with 256 x 256 inputs the last block presumably sees a 1 x 1
    feature map, which would require a batch size above 1 in training mode --
    confirm.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        enc_widths = [in_channels, 64, 128, 256, 512, 512, 512, 512, 512]

        self.encoder = nn.ModuleList()
        for depth, (n_in, n_out) in enumerate(zip(enc_widths, enc_widths[1:])):
            is_first = depth == 0
            self.encoder.append(make_conv(
                # The first block receives the image stacked with the noise.
                in_size=2 * n_in if is_first else n_in,
                out_size=n_out,
                encode=True,
                batch_norm=not is_first,
                activation="leaky",
                drop_out=False))

        dec_widths = [512, 1024, 1024, 1024, 1024, 512, 256, 128, out_channels]
        output_block = len(dec_widths) - 2

        self.decoder = nn.ModuleList()
        for depth, (n_in, n_out) in enumerate(zip(dec_widths, dec_widths[1:])):
            is_hidden = depth < output_block
            self.decoder.append(make_conv(
                in_size=n_in,
                # Hidden blocks emit half of the listed width; the skip
                # connection concatenated in forward() supplies the rest.
                out_size=n_out // 2 if is_hidden else n_out,
                encode=False,
                batch_norm=is_hidden,
                activation="relu" if is_hidden else "tanh",
                # Dropout only in the first three decoder blocks.
                drop_out=depth < 3))

        init_weights(self, mean=0.0, std=0.02)

    def forward(self, x, z):
        """Generate a translation of x conditioned on the noise z."""
        h = torch.cat((x, z), dim=1)

        # Keep every encoder activation for the skip connections.
        features = []
        for stage in self.encoder:
            h = stage(h)
            features.append(h)

        # Deepest activation first, so features[i] feeds decoder block i.
        features.reverse()

        for depth, stage in enumerate(self.decoder):
            # Block 0 consumes the bottleneck itself; later blocks also get
            # the mirrored encoder activation.
            if depth > 0:
                h = torch.cat((h, features[depth]), dim=1)
            h = stage(h)

        return h


class Discriminator(nn.Module):
    """PatchGAN discriminator: C64-C128-C256-C512 plus an output block.

    The output block is a convolution to a single channel followed by a
    Sigmoid. BatchNorm is applied to neither the C64 block nor the output
    block. All ReLUs are leaky with a slope of 0.2.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        widths = [in_channels, 64, 128, 256, 512, 1]
        output_block = len(widths) - 2

        self.blocks = nn.ModuleList()
        for depth, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            self.blocks.append(make_conv(
                # The first block receives both images stacked on channels.
                in_size=2 * n_in if depth == 0 else n_in,
                out_size=n_out,
                encode=True,
                # BatchNorm only on the blocks between the first and last.
                batch_norm=0 < depth < output_block,
                activation="leaky" if depth < output_block else "sigmoid",
                drop_out=False))

        init_weights(self, mean=0.0, std=0.02)

    def forward(self, x, y):
        """Return a nxn tensor of patch probabilities."""
        h = torch.cat((x, y), dim=1)
        for stage in self.blocks:
            h = stage(h)
        return h


if __name__ == '__main__':

    # Smoke test: push random tensors through both networks and print the
    # parameter counts and output shapes.
    batch_size, channels, height, width = 8, 3, 256, 256
    shape = (batch_size, channels, height, width)

    x = torch.randn(shape, device=DEVICE)
    y = torch.randn(shape, device=DEVICE)
    z = torch.randn(shape, device=DEVICE)

    generator = Generator().to(DEVICE)
    total_params = sum(param.numel() for param in generator.parameters())
    print(f"Number of parameters in Generator: {total_params:,}")

    G_z = generator(x, z)
    print(G_z.shape)

    discriminator = Discriminator().to(DEVICE)
    total_params = sum(param.numel() for param in discriminator.parameters())
    print(f"Number of parameters in Discriminator: {total_params:,}")

    D_x = discriminator(x, y)
    print(D_x.shape)