Usually, if your use case stays in the same data domain, the mean and std won't be that different, and you can try using the ImageNet statistics.
I would recommend using your own data statistics if you are dealing with another domain, e.g. medical images.
If you are in doubt whether the ImageNet stats are a good fit for your data, I would recalculate them on your dataset.
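For reference, the per-channel ImageNet statistics commonly used with torchvision's pretrained models (RGB inputs scaled to [0, 1]) are mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225]. Below is a minimal sketch of applying them channel-wise in plain PyTorch. It assumes float images in [0, 1] laid out as [C, H, W] or [B, C, H, W]. `normalize` is a hypothetical helper that mirrors what torchvision.transforms.Normalize does. For another domain you would swap in your own statistics.

```python
import torch

# Commonly used ImageNet per-channel statistics (RGB, pixel values scaled to [0, 1]).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])

def normalize(images: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: channel-wise (x - mean) / std for [C, H, W] or [B, C, H, W]."""
    # Reshape the stats to [C, 1, 1] so they broadcast over the spatial dims
    # (and over a leading batch dim, if present).
    return (images - mean.view(-1, 1, 1)) / std.view(-1, 1, 1)

# Usage with a random stand-in image; replace the stats for non-ImageNet-like data.
img = torch.rand(3, 224, 224)
out = normalize(img, IMAGENET_MEAN, IMAGENET_STD)
```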
Thanks. I think I know how to calculate the mean of my dataset; however, I do not know how to compute the std of my dataset efficiently in PyTorch.
Can you provide me with your code to calculate the mean and std of ImageNet? I would appreciate it very much!
Thanks very much!
This should work:
import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self):
        # random stand-in for real images: [N, C, H, W]
        self.data = torch.randn(100, 3, 24, 24)

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=1,
    shuffle=False
)

mean = 0.
std = 0.
nb_samples = 0.
for data in loader:
    batch_samples = data.size(0)
    # flatten the spatial dims: [B, C, H, W] -> [B, C, H*W]
    data = data.view(batch_samples, data.size(1), -1)
    # accumulate per-image, per-channel mean and std
    mean += data.mean(2).sum(0)
    std += data.std(2).sum(0)
    nb_samples += batch_samples

mean /= nb_samples
std /= nb_samples  # note: this is the average of the per-image stds
If you can load all samples directly into your RAM, the code will be a bit shorter, but I just assume that’s not possible.
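If the whole dataset does fit in memory as a single [N, C, H, W] tensor, the shorter version can reduce over every dimension except the channel dimension. This is a sketch, and the random tensor is only a stand-in for real data:

```python
import torch

# Stand-in for a dataset that fits in RAM: [N, C, H, W].
data = torch.randn(100, 3, 24, 24)

# Reduce over every dim except the channel dim (dim 1).
mean = data.mean(dim=[0, 2, 3])
std = data.std(dim=[0, 2, 3])  # unbiased (N - 1) estimator by default
```

Because this reduces over all pixels of a channel at once, it gives the true per-channel std rather than an average of per-image stds.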
Thanks very much! Appreciate it!
Hi @ptrblck, the code snippet you provided calculates the standard deviation by averaging the per-sample sd values from the mini-batches. While very close to the true sd, it's not calculated exactly. I wonder if the following would be better, albeit slower than your solution:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(1000, 3, 24, 24)

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)

def online_mean_and_sd(loader):
    """Compute the mean and sd in an online fashion

        Var[x] = E[X^2] - E^2[X]
    """
    cnt = 0
    # start from zeros: torch.empty is uninitialized and may contain NaN/inf
    fst_moment = torch.zeros(3)
    snd_moment = torch.zeros(3)
    for data in loader:
        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        # running averages of E[X] and E[X^2] over all pixels seen so far
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels
    return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=1,
    shuffle=False
)
mean, std = online_mean_and_sd(loader)
Thanks for the code!
Your standard deviation calculation seems to come closer to dataset.data.std([0, 2, 3])!
This is computing the average of the per-image standard deviations, which is going to be different from std().
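A tiny example makes the difference concrete. Assume two single-channel 2x2 "images", one all zeros and one all ones. Each image has a std of 0, so the average of the per-image stds is 0, while the std over all eight pixels is 0.5. The sketch below uses population stds (unbiased=False) to keep the numbers simple:

```python
import torch

# Two single-channel 2x2 "images": one all zeros, one all ones -> [2, 1, 2, 2].
data = torch.stack([torch.zeros(1, 2, 2), torch.ones(1, 2, 2)])

flat = data.view(data.size(0), data.size(1), -1)         # [N, C, H*W]
avg_of_stds = flat.std(2, unbiased=False).mean(0)         # mean of per-image stds
overall_std = data.std(dim=[0, 2, 3], unbiased=False)     # std over all pixels
# avg_of_stds is 0, overall_std is 0.5
```

The averaged value misses all of the variation *between* images, which is why it can be far off on real datasets.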
Good answer. Thank you
Sorry for reviving this thread, but your current approach gives a large error in the std estimate if I add an offset to the random data:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

class MyDataset(Dataset):
    def __init__(self):
        # unit-variance data shifted by a large constant offset
        self.data = torch.randn(1000, 3, 24, 24) + 10000

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)

def online_mean_and_sd(loader):
    """Compute the mean and sd in an online fashion

        Var[x] = E[X^2] - E^2[X]
    """
    cnt = 0
    fst_moment = torch.zeros(3)  # zeros instead of the uninitialized torch.empty
    snd_moment = torch.zeros(3)
    for data in loader:
        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels
    # E[X^2] and E^2[X] are both ~1e8 here, so their float32 difference is dominated by rounding error
    return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=1,
    shuffle=False
)
mean, std = online_mean_and_sd(loader)
print(mean, dataset.data.mean([0, 2, 3]))
# > tensor([10000.0039, 9999.9990, 10000.0020]) tensor([10000.0010, 10000.0000, 10000.0000])
print(std, dataset.data.std([0, 2, 3]))
# > tensor([15.4919, 18.7617, 16.7332]) tensor([0.9995, 0.9994, 1.0010])
Correct. I've tested on CIFAR10, and the mean of the per-image pixel stds is [51.6, 50.8, 51.2], while the pixel std over the whole dataset is [63.0, 62.1, 66.7], which is quite different.
The code below lacks optimization but gives the correct std:
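import numpy as np
from tqdm import tqdm

# assumes `dataset` yields (PIL image, label) pairs, e.g. a torchvision dataset without transforms
pixel_mean = np.zeros(3)
pixel_std = np.zeros(3)  # accumulates the sum of squared deviations (Welford's M2)
k = 1
for image, _ in tqdm(dataset, "Computing mean/std", len(dataset), unit="samples"):
    image = np.array(image)
    pixels = image.reshape((-1, image.shape[2]))  # [H, W, C] -> [H*W, C]
    for pixel in pixels:
        # Welford's online update of the mean and M2
        diff = pixel - pixel_mean
        pixel_mean += diff / k
        pixel_std += diff * (pixel - pixel_mean)
        k += 1
# k - 1 pixels were seen, so dividing by k - 2 gives the sample variance
pixel_std = np.sqrt(pixel_std / (k - 2))
print(pixel_mean)
print(pixel_std)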
Adapted from this SO answer.
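The per-pixel Python loop above is exact but slow. A vectorized alternative computes each batch's mean and sum of squared deviations with tensor ops, then merges them into running totals with the parallel variance update of Chan et al. (a batch-wise generalization of Welford's method). This is a sketch, not the code from the SO answer. It assumes the loader yields plain [B, C, H, W] tensors; if it yields (image, label) tuples, unpack them first. It accumulates in float64, which also keeps it accurate for data with a large constant offset.

```python
import torch

def running_mean_std(loader):
    """Per-channel mean and sample std over [B, C, H, W] batches.

    Batch statistics are merged with Chan et al.'s parallel update.
    """
    count, mean, m2 = 0, None, None  # m2: running sum of squared deviations
    for batch in loader:
        # [B, C, H, W] -> [C, B*H*W], accumulated in float64
        x = batch.transpose(0, 1).reshape(batch.size(1), -1).double()
        n_b = x.size(1)
        mean_b = x.mean(dim=1)
        m2_b = ((x - mean_b.unsqueeze(1)) ** 2).sum(dim=1)
        if mean is None:
            count, mean, m2 = n_b, mean_b, m2_b
            continue
        # merge the batch statistics into the running totals
        delta = mean_b - mean
        total = count + n_b
        mean = mean + delta * n_b / total
        m2 = m2 + m2_b + delta ** 2 * count * n_b / total
        count = total
    var = m2 / (count - 1)  # sample variance, like torch.std's default
    return mean, var.sqrt()
```

It returns the sample std (dividing by N - 1), matching torch.std's default. Divide m2 by count instead to get the population std. Any iterable of tensors works, e.g. a DataLoader or simply data.split(batch_size).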
This works for me.
I adjusted your answer. Thanks a lot!
import numpy as np

# single-channel data: `dataloader` yields dicts whose 'image' batch has image.shape[1] == 1;
# for multi-channel [B, C, H, W] batches this reshape would mix the channels
pixel_mean = np.zeros(1)
pixel_std = np.zeros(1)
k = 1
for i_batch, sample_batched in enumerate(dataloader):
    image = np.array(sample_batched['image'])
    pixels = image.reshape((-1, image.shape[1]))
    for pixel in pixels:
        diff = pixel - pixel_mean
        pixel_mean = pixel_mean + diff / k
        pixel_std = pixel_std + diff * (pixel - pixel_mean)
        k += 1
pixel_std = np.sqrt(pixel_std / (k - 2))
print(pixel_mean)  # [5.50180734]
print(pixel_std)  # [8.27773514]
The computation of the std is wrong! The std of the dataset is NOT equal to the average of the stds of the batches. The authors need to edit their answers with a “Do not use this!” warning.
Hi @ptrblck,
std values can't be added and then averaged to get the overall std; please have a look at this: https://stackoverflow.com/a/60803379/8063334
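The reason follows from the law of total variance. Split the $N$ pixel values of one channel into $K$ groups (images or batches) with sizes $n_k$, means $\mu_k$ and population variances $\sigma_k^2$. Then:

$$\mu = \frac{1}{N}\sum_{k=1}^{K} n_k \mu_k, \qquad \sigma^2 = \frac{1}{N}\sum_{k=1}^{K} n_k \sigma_k^2 + \frac{1}{N}\sum_{k=1}^{K} n_k (\mu_k - \mu)^2, \qquad N = \sum_{k=1}^{K} n_k$$

Averaging the group stds instead gives $\bar\sigma = \frac{1}{N}\sum_{k} n_k \sigma_k$, and

$$\bar\sigma \;\le\; \sqrt{\frac{1}{N}\sum_{k=1}^{K} n_k \sigma_k^2} \;\le\; \sigma .$$

The first inequality is Jensen's inequality (the square root is concave). The second drops the non-negative between-group term $\frac{1}{N}\sum_k n_k(\mu_k - \mu)^2$. Equality holds only when all groups share the same std and the same mean, so the averaged estimate never overestimates and usually underestimates. That matches the direction of the CIFAR10 numbers reported above.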
If you are looking for the mean and std of each channel of the dataset, instead of just the overall mean and std, I created the solution below. The code is mainly an extension of ptrblck's code above.
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm_notebook

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(100, 3, 24, 24)

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=1,
    shuffle=False
)

nb_samples = 0.
channel_mean = torch.zeros(3)
channel_std = torch.zeros(3)
for images in tqdm_notebook(loader):
    # scale image to be between 0 and 1 (assumes raw pixel values in [0, 255])
    images = images / 255.
    batch_samples = images.size(0)
    # [B, C, H, W] -> [B, C, H*W]
    images = images.view(batch_samples, images.size(1), -1)
    for i in range(3):
        channel_mean[i] += images[:, i, :].mean(1).sum(0)
        channel_std[i] += images[:, i, :].std(1).sum(0)
    nb_samples += batch_samples

channel_mean /= nb_samples
channel_std /= nb_samples
I have a quick question about the mean and std calculation for my custom dataset (around 5.5k training images). Since these values depend on the batch size, do the mean and std need to be recalculated each time I change the batch_size (as a hyperparameter)? I haven't done this for CIFAR-10 or ImageNet.
Image normalization here isn't based on batch statistics but on dataset statistics. You calculate the mean and std of your entire dataset and use them to normalize each sample of the batch.
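The dataset statistics are computed once, and they do not depend on the batch size used to iterate over the data, as long as the per-batch results are combined exactly. That means summing raw sums and element counts rather than averaging per-batch stds. The sketch below assumes the data is an in-memory [N, C, H, W] tensor. `dataset_mean_std` is a hypothetical helper that uses E[X^2] - E[X]^2 in float64. That is fine for typical pixel ranges, but it can still lose precision for data with a very large constant offset.

```python
import torch

def dataset_mean_std(data: torch.Tensor, batch_size: int):
    """Hypothetical helper: exact per-channel mean and population std of
    [N, C, H, W] data, iterated in batches of `batch_size` samples."""
    c = data.size(1)
    total = torch.zeros(c, dtype=torch.float64)
    total_sq = torch.zeros(c, dtype=torch.float64)
    count = 0
    for batch in data.split(batch_size):
        x = batch.double()
        total += x.sum(dim=[0, 2, 3])          # per-channel sum
        total_sq += (x ** 2).sum(dim=[0, 2, 3])  # per-channel sum of squares
        count += x.size(0) * x.size(2) * x.size(3)
    mean = total / count
    var = total_sq / count - mean ** 2  # population variance
    return mean, var.sqrt()

data = torch.rand(50, 3, 8, 8)  # stand-in for a small image dataset in [0, 1]
mean_10, std_10 = dataset_mean_std(data, batch_size=10)
mean_7, std_7 = dataset_mean_std(data, batch_size=7)
# Same statistics (up to rounding) for any batch size; normalize every sample with them.
normalized = (data - mean_10.view(1, -1, 1, 1)) / std_10.view(1, -1, 1, 1)
```

Changing the training batch size later therefore does not require recomputing the statistics.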
Here you go. I didn't test it rigorously, but it gave the right values for a few test cases:
import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self):
        # known test values: channel 0 is all 0, channel 1 is all 1, channel 2 is 0..23 along the width
        self.data = torch.zeros(100, 3, 24, 24)
        self.data[:, 0:1, :, :] = 0.
        self.data[:, 1:2, :, :] = 1.
        self.data[:, 2:3, :, :] = torch.arange(0., 24., step=1)

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=0,
    shuffle=False
)

# first pass: per-channel mean (every sample has the same number of pixels)
mean = 0.
nb_samples = 0.
for data in loader:
    batch_samples = data.size(0)
    data = data.view(batch_samples, data.size(1), -1)
    mean += data.mean(2).sum(0)
    nb_samples += batch_samples
mean /= nb_samples

# second pass: per-channel sum of squared deviations from the mean
temp = 0.
nb_elements = 0
for data in loader:
    channels = data.size(1)
    elementNum = data.size(0) * data.size(2) * data.size(3)
    data = data.permute(1, 0, 2, 3).reshape(channels, elementNum)
    temp += ((data - mean.view(-1, 1)) ** 2).sum(1)
    nb_elements += elementNum
std = torch.sqrt(temp / nb_elements)  # population std (divides by N, not N - 1)
print(mean)
print(std)
By the way, pretty much all the top Google results for how to calculate the mean and std with PyTorch point to your answer, so having it wrong can really spread misinformation, especially for an error that is as easy to miss as this one.
Also, many of your answers here helped me a lot while learning PyTorch, so in case I don't get another chance: thank you very much.