Hi @ptrblck, the code snippet you provided calculates the standard deviation by averaging the per-mini-batch standard deviations.
While the result is very close to the true standard deviation, it is not exact.
I wonder if the following would be better, albeit slower than your solution:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
class MyDataset(Dataset):
    """In-memory dataset of random tensors, used to exercise the statistics code.

    Holds a single tensor drawn from a standard normal distribution with
    shape (1000, 3, 24, 24); each item is one slice along the first axis.
    """

    def __init__(self):
        # 1000 samples, each of shape (3, 24, 24).
        self.data = torch.randn(1000, 3, 24, 24)

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples (size of the first axis)."""
        return self.data.shape[0]
def online_mean_and_sd(loader):
    """Compute the per-channel mean and standard deviation in one pass.

    Running first and second moments are updated batch by batch and combined
    at the end via ``Var[X] = E[X^2] - E[X]^2`` (population variance).

    Args:
        loader: Iterable yielding 4-D tensors unpacked as ``(b, c, h, w)``.
            Statistics are reduced over the ``b``, ``h`` and ``w`` axes.

    Returns:
        A ``(mean, std)`` tuple of 1-D tensors with one entry per channel.

    Raises:
        ValueError: If ``loader`` yields no elements at all, since the
            statistics are undefined in that case.
    """
    cnt = 0
    fst_moment = None
    snd_moment = None

    for data in loader:
        b, c, h, w = data.shape
        nb_pixels = b * h * w
        if nb_pixels == 0:
            # An empty batch contributes nothing and would cause 0/0 below.
            continue

        if fst_moment is None:
            # Start from zeros (not uninitialized memory, which may hold
            # NaN/inf and poison the running averages), sized from the data
            # instead of assuming three channels.
            fst_moment = torch.zeros(c, device=data.device)
            snd_moment = torch.zeros(c, device=data.device)

        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    if fst_moment is None:
        raise ValueError("loader yielded no data; mean and sd are undefined")

    # Floating-point rounding can push the variance slightly below zero for
    # (near-)constant channels; clamp so sqrt does not return NaN.
    variance = torch.clamp(snd_moment - fst_moment ** 2, min=0)
    return fst_moment, torch.sqrt(variance)
# Feed the synthetic dataset through a loader one un-shuffled sample at a
# time, using a single worker process, and compute the running statistics.
dataset = MyDataset()
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
mean, std = online_mean_and_sd(loader)