How can I compute the statistics for Normalize() automatically?

For the past year I’ve been using mean and standard deviation values that don’t match my dataset, and I would like to know how to compute the correct ones in PyTorch.

Hi @Omar_Zayed,
you can iterate over your dataset, collect the data, and then compute the per-channel statistics.

Here is an example with the MNIST dataset:

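```python
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# ToTensor() converts each image to a float tensor in [0, 1] with shape (C, H, W)
dataset = MNIST(root=".", download=True, transform=ToTensor())
dt = DataLoader(dataset, batch_size=8)

data = []

for sample, label in dt:
    # sample has shape (batch_size, C, H, W)
    data.append(sample)

# Concatenate along the batch dimension -> (N, C, H, W).
# Unlike torch.stack, this also works when the last batch is smaller.
data_all = torch.cat(data, dim=0)

# stats per channel: reduce over samples, height and width
mean = data_all.mean(dim=[0, 2, 3])
std = data_all.std(dim=[0, 2, 3])

print(f"mean: {mean}")
print(f"std: {std}")
```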
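Collecting every batch in a list keeps the whole dataset in memory, which is fine for MNIST but may not be for larger image datasets. A streaming alternative is to accumulate per-channel sums and sums of squares batch by batch and derive the statistics at the end. The sketch below assumes the loader yields `(images, labels)` with `images` shaped `(B, C, H, W)`; `compute_channel_stats` is a hypothetical helper name, and it returns the population std (dividing by N rather than N - 1), which is practically identical for millions of pixels.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def compute_channel_stats(loader):
    """Hypothetical helper: per-channel mean/std over a DataLoader of (images, labels).

    Assumes images are shaped (B, C, H, W). Accumulates in float64 to limit
    rounding error and returns the population std (divides by N, not N - 1).
    """
    total = None
    total_sq = None
    count = 0
    for images, _ in loader:
        images = images.to(torch.float64)
        if total is None:
            total = torch.zeros(images.size(1), dtype=torch.float64)
            total_sq = torch.zeros(images.size(1), dtype=torch.float64)
        # Sum over batch, height and width, keeping the channel dimension
        total += images.sum(dim=[0, 2, 3])
        total_sq += (images ** 2).sum(dim=[0, 2, 3])
        # Number of values per channel contributed by this batch
        count += images.size(0) * images.size(2) * images.size(3)
    mean = total / count
    # Var = E[x^2] - E[x]^2; clamp guards against tiny negative values from rounding
    var = (total_sq / count - mean ** 2).clamp(min=0)
    return mean.float(), var.sqrt().float()


if __name__ == "__main__":
    # Synthetic 3-channel data so the sketch runs without downloading anything;
    # 100 samples with batch_size=16 leaves a smaller final batch on purpose.
    torch.manual_seed(0)
    images = torch.rand(100, 3, 8, 8)
    labels = torch.zeros(100, dtype=torch.long)
    loader = DataLoader(TensorDataset(images, labels), batch_size=16)

    mean, std = compute_channel_stats(loader)
    print(f"mean: {mean}")
    print(f"std: {std}")
```

The resulting values can then be passed to `torchvision.transforms.Normalize(mean.tolist(), std.tolist())`. Compute them on the training split with only `ToTensor()` in the transform, so that no normalization is applied before the statistics are measured.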