What is the equivalent of 'tf.keras.layers.experimental.preprocessing.Normalization' in PyTorch?

I am trying to switch my model from TensorFlow/Keras to PyTorch, but I ran into a problem. I have “tf.keras.layers.experimental.preprocessing.Normalization” as a layer in my model, so the model saves the mean and std as model parameters. Therefore, at test time, I can use these statistics to normalize my test set. I do this because I want my trained model to be usable by different users who don’t have access to the whole training data. How can I do the same thing in PyTorch?
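For reference, the pattern I mean looks roughly like this (toy data and model, just to illustrate):

import numpy as np
import tensorflow as tf

# toy data; shapes and values are placeholders
x_train = np.random.rand(1000, 3).astype('float32')

norm = tf.keras.layers.experimental.preprocessing.Normalization()
norm.adapt(x_train)  # computes mean/variance and stores them in the layer

model = tf.keras.Sequential([
    norm,  # the statistics are saved together with the model
    tf.keras.layers.Dense(1),
])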

The PyTorch way™ is to precompute the mean and std yourself and pass the result to torchvision.transforms.Normalize.
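For example (a minimal sketch; the numbers are placeholders for whatever statistics you precompute):

import torchvision.transforms as transforms

# placeholder per-channel statistics; substitute the values you computed
mean = [0.491, 0.482, 0.447]
std = [0.247, 0.244, 0.262]

transform = transforms.Compose([
    transforms.ToTensor(),                     # scales pixel values to [0, 1]
    transforms.Normalize(mean=mean, std=std),  # (x - mean) / std per channel
])

Any dataset built with this transform then feeds normalized inputs to the model.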

I was going to link to a forum example of how to do the precomputation, but people there kept posting new versions, each supposedly less buggy than the last, and I got confused, so here is a quick and dirty example that recomputes the per-channel mean and std for CIFAR-10. (For CIFAR you wouldn’t actually need the streaming computation, because the dataset fits in memory and you can compute mean and std in one go, but that lets us check the result.) The algorithm is basically Welford’s online variance algorithm, generalized to merging an entire batch per update (the parallel variant linked in the code).

import torch
import torch.utils.data
import torchvision

ds = torchvision.datasets.CIFAR10('/mnt/data/vision/cifar10', transform=torchvision.transforms.ToTensor())
dl = torch.utils.data.DataLoader(ds, batch_size=1024)

# Running per-channel statistics. count is in images rather than pixels;
# since every image has the same number of pixels, the constant factor
# cancels in both the mean and M2 / count.
count = 0
mean = torch.zeros((3,), dtype=torch.double)  # running per-channel mean
M2 = torch.zeros((3,), dtype=torch.double)    # running sum of squared deviations

for im, _ in dl:
    # parallel (batched) Welford update, see
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    im = im.to(dtype=torch.double)
    batch_count = im.size(0)
    batch_mean = im.mean(dim=(0, 2, 3))
    # unbiased=False gives the population variance the update formula expects
    batch_M2 = im.var(dim=(0, 2, 3), unbiased=False) * batch_count
    delta = batch_mean - mean
    mean += delta * batch_count / (count + batch_count)
    M2 += batch_M2 + delta ** 2 * count * batch_count / (count + batch_count)
    count += batch_count

std = (M2 / count) ** 0.5

# for CIFAR, we can check against the one-shot computation over the raw data
assert (torch.from_numpy(ds.data / 255).mean(dim=(0, 1, 2)) - mean).abs().max() < 1e-6
assert (torch.from_numpy(ds.data / 255).std(dim=(0, 1, 2)) - std).abs().max() < 1e-6

Best regards

Thomas

Thanks Thomas. I had also thought of logging the mean and std while training my model and providing them as metadata to the user who is going to run the trained model on his/her test set, so s/he can use the transform to normalize the data before feeding it to the pretrained model.
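One way to make the statistics travel inside the checkpoint itself, like the Keras layer does, seems to be registering them as buffers so they end up in the state_dict. A rough sketch (NormalizeInput and the stand-in network are made up for illustration):

import torch
import torch.nn as nn

class NormalizeInput(nn.Module):
    # hypothetical helper: keeps per-channel mean/std as buffers so they
    # are stored in and loaded from the state_dict, like the Keras layer
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer('mean', torch.as_tensor(mean).view(1, -1, 1, 1))
        self.register_buffer('std', torch.as_tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std

# placeholder usage: prepend it to whatever network you actually train
model = nn.Sequential(
    NormalizeInput(mean=[0.491, 0.482, 0.447], std=[0.247, 0.244, 0.262]),
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # stand-in for the real model
)
torch.save(model.state_dict(), 'model.pt')  # the buffers are saved along with the weights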