Computing the mean and std of a dataset

import numpy as np
from PIL import ImageStat

class Stats(ImageStat.Stat):
    def __add__(self, other):
        # combine two Stat objects by adding their per-band histograms element-wise
        return Stats(list(np.add(self.h, other.h)))
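
A hedged usage sketch (the image_paths list is a hypothetical name, not from the original post): build a Stats object per image, sum them, and read the pooled per-channel statistics from the combined histogram.

from PIL import Image

# hypothetical list of image file paths
image_paths = ["img_0001.jpg", "img_0002.jpg"]

stats = None
for path in image_paths:
    s = Stats(Image.open(path).convert("RGB"))   # per-image histogram statistics
    stats = s if stats is None else stats + s    # histograms add element-wise

print("per-channel mean:", stats.mean)     # on the 0-255 scale
print("per-channel std: ", stats.stddev)

Divide by 255 if you need the values on the 0-1 scale that ToTensor produces.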

If I am training my model with a batch size of 4, should I also compute the mean and std with a batch size of 4? Or is it more accurate to compute them with larger batches (e.g. 8) and then train with a batch size of 4?

Thanks.

Finally, is there a good method to calculate the mean and std?

Any batch_size should work: the mean and std are statistics of the whole dataset, so the loader's batch size only changes how they are accumulated, not the result. The training batch_size isn't directly related to the batch_size you use for calculating the mean and std, so you could choose 4 for both, or 4 for training and 8 for the statistics.
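
As a quick sanity check (a minimal sketch on a random tensor dataset, not part of the original posts), the per-channel mean comes out the same regardless of the loader batch size:

import torch
from torch.utils import data

# toy dataset: 100 random RGB "images" of size 32x32 with dummy labels
dataset = data.TensorDataset(torch.rand(100, 3, 32, 32), torch.zeros(100))

def channel_mean(batch_size):
    loader = data.DataLoader(dataset, batch_size=batch_size)
    total = torch.zeros(3)
    for images, _ in loader:
        total += images.mean(dim=(2, 3)).sum(0)   # sum of per-image channel means
    return total / len(dataset)

print(torch.allclose(channel_mean(4), channel_mean(8)))  # True (up to float rounding)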

Very late to this thread, but I think this version is (almost) mathematically correct.

Instead of a center crop, one can count the number of pixels, e.g. pixel_count += images.nelement(), if the image sizes differ.

import torch
from torch.utils import data
from torchvision import datasets, transforms

dataset = datasets.ImageFolder('train', transform=transforms.Compose([transforms.ToTensor()]))

loader = data.DataLoader(dataset,
                         batch_size=10,
                         num_workers=0,
                         shuffle=False,
                         drop_last=False)

# first pass: per-channel mean over the whole dataset
mean = 0.0
for images, _ in loader:
    batch_samples = images.size(0)                            # images in this batch
    images = images.view(batch_samples, images.size(1), -1)   # flatten H and W
    mean += images.mean(2).sum(0)                             # sum of per-image channel means
mean = mean / len(loader.dataset)

# second pass: per-channel variance, then std
var = 0.0
pixel_count = 0
for images, _ in loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1)
    var += ((images - mean.unsqueeze(1)) ** 2).sum([0, 2])
    pixel_count += images.nelement()
std = torch.sqrt(var / pixel_count)

The code looks good!
But there is an issue when counting the number of pixels: since the variance is accumulated per channel, the channel dimension should be excluded from the count:

pixel_count += images.nelement() / images.size(1)

The updated version:


loader = data.DataLoader(dataset,
                         batch_size=10,
                         num_workers=0,
                         shuffle=False,
                         drop_last=False)

# first pass: per-channel mean over the whole dataset
mean = 0.0
for images, _ in loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
mean = mean / len(loader.dataset)

# second pass: per-channel variance, counting pixels per channel
var = 0.0
pixel_count = 0
for images, _ in loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1)
    var += ((images - mean.unsqueeze(1)) ** 2).sum([0, 2])
    pixel_count += images.nelement() / images.size(1)   # exclude the channel dimension
std = torch.sqrt(var / pixel_count)
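
Once mean and std are computed, they would typically be passed to transforms.Normalize; a short sketch (re-creating the dataset this way is my addition, not part of the original post):

normalized_dataset = datasets.ImageFolder(
    'train',
    transform=transforms.Compose([
        transforms.ToTensor(),
        # per-channel statistics computed above
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ]))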