Since last year, I’ve been using mean and standard deviation values that are far from my dataset's actual statistics, and I would like to know how to compute them in PyTorch.
Hi @Omar_Zayed,
you can just iterate over your dataset, collect data and then compute statistics.
Here is an example with the MNIST dataset:
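```python
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torch

# ToTensor() scales pixels to [0, 1], so the stats are computed in that range
dataset = MNIST(root=".", download=True, transform=ToTensor())
dt = DataLoader(dataset, batch_size=8)

data = []
for sample, _ in dt:
    data.append(sample)  # each sample batch has shape [B, C, H, W]

# torch.cat (rather than torch.stack) also works when the last batch is smaller
data_cat = torch.cat(data, dim=0)  # shape [N, C, H, W]

# stats per channel: reduce over samples, height and width
mean = data_cat.mean(dim=[0, 2, 3])
std = data_cat.std(dim=[0, 2, 3])
print(f"mean: {mean}")
print(f"std: {std}")
```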
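Collecting every image into one tensor is fine for MNIST but may not fit in memory for larger datasets. A minimal sketch of a streaming alternative, assuming the loader yields `(images, labels)` batches shaped `[B, C, H, W]` as `ToTensor()` produces: keep per-channel running sums of the values and of their squares, then derive mean and std at the end. `channel_stats` is a hypothetical helper, not a PyTorch API.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_stats(loader):
    """Per-channel mean and std over every pixel yielded by `loader`.

    Hypothetical helper; assumes batches of (images, labels) with images
    shaped [B, C, H, W]. Only running sums are kept, not the whole dataset.
    """
    total = None     # running per-channel sum of pixel values
    total_sq = None  # running per-channel sum of squared pixel values
    count = 0        # number of pixels seen per channel
    for images, _ in loader:
        # float64 accumulation reduces cancellation error in the variance
        images = images.to(torch.float64)
        if total is None:
            total = torch.zeros(images.size(1), dtype=torch.float64)
            total_sq = torch.zeros_like(total)
        total += images.sum(dim=[0, 2, 3])
        total_sq += (images ** 2).sum(dim=[0, 2, 3])
        count += images.numel() // images.size(1)
    mean = total / count
    # Unbiased variance, matching Tensor.std's default correction of 1
    var = (total_sq - count * mean ** 2) / (count - 1)
    return mean.float(), var.sqrt().float()
```

For MNIST this would be called as `channel_stats(DataLoader(dataset, batch_size=256))`; a larger batch size only speeds things up and does not change the result. The returned values can then be passed to `torchvision.transforms.Normalize(mean.tolist(), std.tolist())`.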