About normalization when using pre-trained VGG16 networks

Hi @ptrblck, the code snippet you provided calculates the standard deviation by averaging the per-mini-batch standard deviations. While very close to the true standard deviation, it is not exact. I wonder if the following, which accumulates exact per-channel sums over all pixels, would be better, albeit slower than your solution:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

class MyDataset(Dataset):
    """Toy in-memory dataset of 1000 standard-normal tensors of shape (3, 24, 24)."""

    def __init__(self):
        # All samples are drawn once, up front; indexing only reads from here.
        self.data = torch.randn(1000, 3, 24, 24)

    def __getitem__(self, index):
        """Return the sample stored at position ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples (the size of the first dimension)."""
        return self.data.shape[0]

def online_mean_and_sd(loader):
    """Compute the exact per-channel mean and standard deviation in one pass.

    Rather than averaging per-batch estimates, the per-channel sums of x and
    x**2 are accumulated over every pixel of every batch and combined at the
    end via

        Var[x] = E[X^2] - E^2[X]

    Args:
        loader: Any iterable yielding 4-D tensors shaped
            (batch, channels, height, width), e.g. a
            ``torch.utils.data.DataLoader``. The channel count is taken from
            the data rather than being hard-coded.

    Returns:
        ``(mean, sd)``: two 1-D tensors of length ``channels``. ``sd`` is the
        population standard deviation (it divides by the total pixel count).
        Their dtype is the promotion of the input dtype with the default float
        dtype (float32 for float32 or integer images).

    Raises:
        ValueError: If the loader yields no pixels at all.
    """
    cnt = 0
    total_sum = None
    total_sq_sum = None
    out_dtype = None

    for data in loader:
        b, c, h, w = data.shape
        if out_dtype is None:
            out_dtype = torch.promote_types(data.dtype, torch.get_default_dtype())
        # Accumulate in float64: this avoids integer overflow of data ** 2 for
        # uint8 images and limits cancellation in E[X^2] - E^2[X].
        data = data.to(torch.float64)
        batch_sum = data.sum(dim=(0, 2, 3))
        batch_sq_sum = (data * data).sum(dim=(0, 2, 3))
        if total_sum is None:
            # Zero-free start: no uninitialised (torch.empty) accumulators.
            total_sum, total_sq_sum = batch_sum, batch_sq_sum
        else:
            total_sum += batch_sum
            total_sq_sum += batch_sq_sum
        cnt += b * h * w

    if cnt == 0:
        raise ValueError("loader yielded no pixels; cannot compute mean/sd")

    mean = total_sum / cnt
    # Rounding can push the variance of near-constant data slightly below
    # zero, which would make sqrt return NaN.
    var = (total_sq_sum / cnt - mean * mean).clamp_min(0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)

    

if __name__ == "__main__":
    # The guard is required: with num_workers > 0, DataLoader starts worker
    # processes, and under the "spawn" start method (Windows, macOS) each
    # worker re-imports this module and would otherwise re-run this script.
    dataset = MyDataset()
    loader = DataLoader(
        dataset,
        # The statistics are exact regardless of batch size, so use a larger
        # batch than 1 to cut per-batch overhead.
        batch_size=64,
        num_workers=1,
        shuffle=False,
    )

    mean, std = online_mean_and_sd(loader)
    print(f"mean: {mean}")
    print(f"std: {std}")
10 Likes