I am trying to code a Variational Autoencoder (VAE) for the MNIST dataset, and the data pre-processing is as follows:
import torch
import torchvision

# Create transformations to be applied to dataset-
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            (0.1307,), (0.3081,)
            # (0.5,), (0.5,)
        )
    ]
)
# Create training and validation datasets-
train_dataset = torchvision.datasets.MNIST(
    # root = 'data', train = True,
    root = path_to_data, train = True,
    download = True, transform = transforms
)
val_dataset = torchvision.datasets.MNIST(
    # root = 'data', train = False,
    root = path_to_data, train = False,
    download = True, transform = transforms
)
# Sanity check-
len(train_dataset), len(val_dataset)
# (60000, 10000)
# Create training and validation data loaders-
train_dataloader = torch.utils.data.DataLoader(
    dataset = train_dataset, batch_size = 32,
    shuffle = True,
    # num_workers = 2
)
val_dataloader = torch.utils.data.DataLoader(
    dataset = val_dataset, batch_size = 32,
    shuffle = True,
    # num_workers = 2
)
# Get a mini-batch of train data loaders-
imgs, labels = next(iter(train_dataloader))
imgs.shape, labels.shape
# (torch.Size([32, 1, 28, 28]), torch.Size([32]))
# Minimum & maximum pixel values-
imgs.min(), imgs.max()
# (tensor(-0.4242), tensor(2.8215))
# Compute min and max for train dataloader-
min_mnist, max_mnist = 0.0, 0.0
for img, _ in train_dataloader:
    if img.min() < min_mnist:
        min_mnist = img.min()
    if img.max() > max_mnist:
        max_mnist = img.max()

print(f"MNIST - train: min pixel value = {min_mnist:.4f} & max pixel value = {max_mnist:.4f}")
# MNIST - train: min pixel value = -0.4242 & max pixel value = 2.8215
min_mnist, max_mnist = 0.0, 0.0
for img, _ in val_dataloader:
    if img.min() < min_mnist:
        min_mnist = img.min()
    if img.max() > max_mnist:
        max_mnist = img.max()

print(f"MNIST - validation: min pixel value = {min_mnist:.4f} & max pixel value = {max_mnist:.4f}")
# MNIST - validation: min pixel value = -0.4242 & max pixel value = 2.8215
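These empirical extremes can also be checked analytically: ‘ToTensor()’ maps pixels into [0, 1], so after ‘Normalize((0.1307,), (0.3081,))’ the minimum is (0 - 0.1307)/0.3081 and the maximum is (1 - 0.1307)/0.3081. A quick sanity check in plain Python (no dataset needed):

```python
# Analytic check of the normalized pixel range:
# ToTensor() yields values in [0, 1]; Normalize(mean, std) computes (x - mean) / std.
mean, std = 0.1307, 0.3081

lo = (0.0 - mean) / std  # normalized value of a black pixel
hi = (1.0 - mean) / std  # normalized value of a white pixel

print(f"min = {lo:.4f}, max = {hi:.4f}")
# min = -0.4242, max = 2.8215
```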
With the ‘ToTensor()’ and ‘Normalize()’ transforms above, the input pixels lie in the range [-0.4242, 2.8215]. The output layer of the decoder within the VAE typically uses either a sigmoid or a tanh activation function: sigmoid outputs values in the range [0, 1], while tanh outputs values in the range [-1, 1].
This is a problem, since the input is in the range [-0.4242, 2.8215], while the output is confined to [0, 1] or [-1, 1] depending on which activation is used.
One simple fix is to use only the ‘ToTensor()’ transformation, which scales the input to the range [0, 1], and then use a sigmoid activation for the decoder’s output layer. But what is a better approach for pre-processing images that need per-channel normalization with ‘Normalize()’, such that the input and the output/reconstructions are in the same range?
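To make the range mismatch concrete: since ‘Normalize()’ computes (x - mean)/std, its inverse is simply x * std + mean. A minimal sketch (the `denormalize` helper is hypothetical, not part of my code) showing that inverting the transform maps the normalized MNIST extremes back to the [0, 1] range that a sigmoid output lives in:

```python
# Hypothetical helper: invert Normalize((0.1307,), (0.3081,)).
# Normalize computes (x - mean) / std, so the inverse is x * std + mean.
def denormalize(x, mean=0.1307, std=0.3081):
    return x * std + mean

# The normalized MNIST extremes map back to the [0, 1] range of ToTensor():
print(round(denormalize(-0.4242), 2))  # ~0.0 (black pixel)
print(round(denormalize(2.8215), 2))   # ~1.0 (white pixel)
```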