Computing the per-channel mean and std of an image dataset

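A late answer, but I believe this version is (almost) mathematically correct. It makes two passes over the data: the first computes the per-channel mean, and the second computes the per-channel variance around that mean.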

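If the image sizes differ, you do not need to center-crop every image to a fixed size. Instead, keep a running count of the pixels seen. The count must be per channel, e.g. `pixel_count += images.size(0) * images.size(2)` after flattening to (batch, channels, pixels). `images.nelement()` also multiplies by the number of channels and would shrink the variance by that factor. Note that the default DataLoader collate function stacks the images of a batch, so images within one batch must still share a size (e.g. use `batch_size=1`).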

# Every image found under ./train, converted to a tensor by ToTensor.
# No resizing or cropping is applied here.
dataset = datasets.ImageFolder(
    'train',
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Deterministic, single-process iteration over every sample.
# The final, possibly smaller, batch is kept.
# No collate_fn is given, so the default collate stacks each batch;
# the images inside one batch therefore have to share the same size.
loader = data.DataLoader(
    dataset,
    drop_last=False,
    shuffle=False,
    num_workers=0,
    batch_size=10,
)

mean = 0.0
for images, _ in loader:
    batch_samples = images.size(0) 
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
mean = mean / len(loader.dataset)

# Pass 2: per-channel (population) standard deviation about `mean`.
var = 0.0
pixel_count = 0  # pixels *per channel* seen so far (batch * pixels)
for images, _ in loader:
    batch_samples = images.size(0)
    # Flatten everything after the channel dim: (batch, channels, pixels).
    images = images.view(batch_samples, images.size(1), -1)
    # mean.unsqueeze(1) is (channels, 1) and broadcasts over batch and pixels.
    # The squared deviations are summed over batch and pixels, giving one value
    # per channel. The running total is kept in float64 for precision.
    var += ((images - mean.unsqueeze(1)) ** 2).sum([0, 2]).double()
    # Count pixels per channel only. images.nelement() would also multiply by
    # the number of channels, which shrinks the variance by that factor.
    pixel_count += batch_samples * images.size(2)
if pixel_count == 0:
    raise ValueError("cannot compute statistics of an empty dataset")
std = torch.sqrt(var / pixel_count).float()