I am trying to use the pretrained VGG16 network to extract features (not fine-tuning) for my own task dataset, such as UCF101, rather than ImageNet. Since VGG16 is trained on ImageNet, I see a lot of people simply reuse the mean and std statistics calculated for ImageNet (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) to normalize images from their own datasets.

Now I am confused. If I want to extract features with VGG16 (pretrained on ImageNet) on my dataset, should I subtract the ImageNet mean, or should I first calculate my dataset's mean and std? Is it necessary? Is there generally a big difference between ImageNet and other RGB datasets?

Usually, if your use case stays in the same data domain, the mean and std won't be that different, and you can try using the ImageNet statistics.
I would recommend using your own data statistics if you are dealing with another domain, e.g. medical images.
If you are in doubt whether the ImageNet stats are a good fit for your data, I would recalculate them on your dataset.
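Whichever statistics you choose, they are applied the same way: subtract the per-channel mean and divide by the per-channel std. Below is a minimal sketch in plain PyTorch, assuming float images in [0, 1] laid out as [N, C, H, W] or [C, H, W]. `torchvision.transforms.Normalize(mean, std)` performs the same per-channel operation; the `normalize` helper and the dummy batch are hypothetical and only illustrate the arithmetic.

```python
import torch

# ImageNet per-channel statistics quoted in the question (RGB order, inputs in [0, 1]).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])


def normalize(images: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-channel (x - mean) / std for [C, H, W] or [N, C, H, W]."""
    # Reshape the stats to [C, 1, 1] so they broadcast over height and width.
    return (images - mean[:, None, None]) / std[:, None, None]


# Dummy batch of 2 RGB images with values in [0, 1].
batch = torch.rand(2, 3, 224, 224)

# Option 1: reuse the ImageNet statistics.
with_imagenet_stats = normalize(batch, IMAGENET_MEAN, IMAGENET_STD)

# Option 2: statistics computed on your own data (here, just the dummy batch).
own_mean = batch.mean(dim=[0, 2, 3])
own_std = batch.std(dim=[0, 2, 3])
with_own_stats = normalize(batch, own_mean, own_std)
```

With your own statistics, each channel of the normalized data ends up with mean 0 and std 1; with the ImageNet statistics it only comes close to that if your data resembles ImageNet.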

Thanks. I think I know how to calculate the mean of my dataset, but I do not know how to compute its std efficiently in PyTorch.
Could you share the code you used to calculate the mean and std of ImageNet? I would appreciate it very much!
Thanks very much!

Hi @ptrblck, the code snippet you provided calculates the standard deviation by averaging the sd of each mini-batch. While this is very close to the true sd, it is not exact. I wonder if the following would be better, albeit slower than your solution:

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset


class MyDataset(Dataset):
    """Dummy dataset of 1000 random 3x24x24 "images"."""

    def __init__(self):
        self.data = torch.randn(1000, 3, 24, 24)

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)


def online_mean_and_sd(loader):
    """Compute the per-channel mean and sd in an online fashion.

    Var[X] = E[X^2] - E[X]^2
    """
    cnt = 0
    # Start from zeros: torch.empty returns uninitialized memory, and 0 * nan/inf is nan.
    fst_moment = torch.zeros(3)
    snd_moment = torch.zeros(3)

    for data in loader:
        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        # Running averages of X and X^2, weighted by the number of pixels seen so far.
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)


dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=1,
    shuffle=False
)
mean, std = online_mean_and_sd(loader)
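The E[X^2] - E[X]^2 formula in `online_mean_and_sd` is exact in exact arithmetic, but in float32 it can lose precision when the mean is large relative to the spread, because two nearly equal numbers are subtracted. A minimal sketch of a swapped-in alternative, Chan et al.'s pairwise merge of per-batch (count, mean, M2) statistics accumulated in float64, is shown below. The function name `merged_mean_and_std` and the dummy data are hypothetical; it assumes the loader yields [B, C, H, W] tensors, like `MyDataset` does.

```python
import torch
from torch.utils.data import DataLoader


def merged_mean_and_std(loader):
    """Per-channel population mean/std by merging per-batch (count, mean, M2)."""
    count, mean, m2 = 0, None, None  # m2: per-channel sum of squared deviations
    for images in loader:  # images: [B, C, H, W]
        x = images.to(torch.float64)
        n_b = x.shape[0] * x.shape[2] * x.shape[3]  # pixels per channel in this batch
        mean_b = x.mean(dim=[0, 2, 3])
        m2_b = ((x - mean_b[None, :, None, None]) ** 2).sum(dim=[0, 2, 3])
        if mean is None:
            count, mean, m2 = n_b, mean_b, m2_b
            continue
        # Chan et al. merge: combine two groups without revisiting their pixels.
        delta = mean_b - mean
        total = count + n_b
        mean = mean + delta * n_b / total
        m2 = m2 + m2_b + delta ** 2 * count * n_b / total
        count = total
    # Divide by N (population std), like online_mean_and_sd.
    return mean, torch.sqrt(m2 / count)


# Dummy data with a large mean and a small spread, where E[X^2] - E[X]^2 struggles.
data = torch.randn(200, 3, 16, 16) * 0.1 + 100.0
loader = DataLoader(data, batch_size=32, shuffle=False)  # a tensor works as a map-style dataset
mean, std = merged_mean_and_std(loader)
```

The result is exact up to float64 rounding and does not depend on the batch size; the cost is one extra pass over each batch to compute its M2.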

The std computation is wrong! The std of a dataset is NOT equal to the average of the stds of its batches. The authors need to edit their answers with a "Do not use this!" warning.
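A tiny counterexample makes the point concrete. A per-batch std only measures the spread within each batch and ignores how far apart the batch means are (law of total variance: total variance = mean of within-batch variances + variance of the batch means). For equal-sized batches the average of batch stds can therefore only underestimate the dataset std. The values below are made up purely for illustration, and `pop_std` is a hypothetical helper for the population (divide-by-N) std.

```python
import torch


def pop_std(x: torch.Tensor) -> torch.Tensor:
    """Population standard deviation (divide by N) of all elements of x."""
    return ((x - x.mean()) ** 2).mean().sqrt()


# Two "batches" of pixel values for a single channel.
batch_a = torch.tensor([0.0, 0.0, 0.0, 0.0])
batch_b = torch.tensor([10.0, 10.0, 10.0, 10.0])

# Each batch is constant, so each batch std is 0 and so is their average.
avg_of_batch_stds = (pop_std(batch_a) + pop_std(batch_b)) / 2

# Over the whole dataset, half the values are 0 and half are 10 (mean 5).
dataset_std = pop_std(torch.cat([batch_a, batch_b]))

print(avg_of_batch_stds.item(), dataset_std.item())  # 0.0 5.0
```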

If you are looking for the mean and std of each channel of the dataset, rather than a single mean and std, I created the solution below. The code is mainly an extension of ptrblck's code above.

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm  # progress bar for Jupyter notebooks


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.randn(100, 3, 24, 24)

    def __getitem__(self, index):
        x = self.data[index]
        return x

    def __len__(self):
        return len(self.data)


dataset = MyDataset()
loader = DataLoader(
    dataset,
    batch_size=10,
    num_workers=1,
    shuffle=False
)

nb_samples = 0.
channel_mean = torch.zeros(3)
channel_std = torch.zeros(3)
for images in tqdm(loader):
    # If the loader yields raw 0-255 pixel values, scale them to [0, 1] first:
    # images = images / 255.
    batch_samples = images.size(0)
    # Flatten each image to [C, H*W] so that each row holds one channel's pixels.
    images = images.view(batch_samples, images.size(1), -1)
    # Sum the per-image channel means and stds over the batch.
    # Note: averaging per-image stds only approximates the dataset std.
    channel_mean += images.mean(2).sum(0)
    channel_std += images.std(2).sum(0)
    nb_samples += batch_samples
channel_mean /= nb_samples
channel_std /= nb_samples

I have a quick question about the mean and std calculation for my custom dataset (around 5.5k training images). Since these values seem to depend on the batch size, do the mean and std need to be recalculated every time I change the batch size (as a hyperparameter)? I haven't done this for CIFAR-10 or ImageNet.

Image normalization here isn't based on batch statistics but on dataset statistics. You calculate the mean and std of your entire dataset once and use them to normalize each sample of the batch, so changing the batch size does not require recomputing them.
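To illustrate that answer: when the statistics are computed exactly over every pixel of the dataset, the batch size only changes how the work is chunked, not the result. The sketch below assumes the loader yields [B, C, H, W] tensors; `dataset_mean_and_std` and the random `data` tensor are hypothetical stand-ins for your own dataset.

```python
import torch
from torch.utils.data import DataLoader


def dataset_mean_and_std(loader):
    """Exact per-channel mean/std over all pixels, accumulated in float64."""
    total, total_sq, count = 0.0, 0.0, 0
    for images in loader:  # images: [B, C, H, W]
        x = images.to(torch.float64)
        total = total + x.sum(dim=[0, 2, 3])
        total_sq = total_sq + (x ** 2).sum(dim=[0, 2, 3])
        count += x.shape[0] * x.shape[2] * x.shape[3]
    mean = total / count
    std = torch.sqrt(total_sq / count - mean ** 2)
    return mean, std


# Hypothetical dataset of 120 RGB images with values in [0, 1].
data = torch.rand(120, 3, 8, 8)

# Same statistics for every batch size (up to float64 rounding).
results = {bs: dataset_mean_and_std(DataLoader(data, batch_size=bs)) for bs in (1, 7, 120)}
```

Once computed, the same `mean`/`std` pair is applied to every sample (e.g. via `transforms.Normalize`), however the samples are batched later.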