Computing the mean and std of a dataset

It might be a bit late, but PIL provides nice functionality for your problem in the ImageStat.Stat class. Its calculations are based on the image histogram, which has a fixed number of bins per band, so they need only O(1) memory. However, it only considers one image. To handle multiple images, I extended the Stat class with an __add__ method that sums the histograms of two objects (which is a bit like concatenating the two images and generating a Stat object from the result):

from operator import add
from PIL import ImageStat

class Stats(ImageStat.Stat):
    def __add__(self, other):  # bin-wise sum of the two histograms
        return Stats(list(map(add, self.h, other.h)))

The histogram is stored in h. The histograms of self and other are summed element-wise, and a new Stats object is initialized from the combined histogram instead of from an image.
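A library-free sketch of why summing histograms works: the pixel count, the sum and the sum of squares of a band can all be read off its bins, and all three are additive. Adding two histograms bin by bin therefore gives exactly the histogram, and hence the statistics, of both images' pixels taken together. `histogram` and `band_stats` are hypothetical helpers, not Pillow API; `band_stats` assumes the population variance (dividing by the pixel count), which is what ImageStat.Stat reports as far as I know.

```python
import math
from operator import add


def histogram(pixels):
    """256-bin histogram of a sequence of 8-bit values (one band)."""
    hist = [0] * 256
    for p in pixels:
        hist[p] += 1
    return hist


def band_stats(hist):
    """Mean and population std of one band, computed from its histogram only."""
    count = sum(hist)
    total = sum(value * n for value, n in enumerate(hist))
    total_sq = sum(value * value * n for value, n in enumerate(hist))
    mean = total / count
    var = (total_sq - total * total / count) / count
    return mean, math.sqrt(var)


# Two hypothetical single-band "images" of different sizes
image_a = [10, 20, 30, 40]
image_b = [20, 30]

# What Stats.__add__ does: add the two histograms bin by bin
merged = list(map(add, histogram(image_a), histogram(image_b)))

# The merged histogram equals the histogram of the concatenated pixels,
# so the mean and std computed from it match as well
print(merged == histogram(image_a + image_b))  # True
print(band_stats(merged))
```

For an RGB image, Pillow's `Image.histogram()` returns 768 values (256 per band, one band after another), so the same bin-wise addition lines up band by band as long as both images have the same mode.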

Using this new Stats class, I could do something like:

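import torchvision.transforms.functional as tf
from torch.utils.data import DataLoader

# assumes dataset yields image tensors only (no labels), each shaped (C, H, W)
loader = DataLoader(dataset, batch_size=10, num_workers=5)

statistics = None
for data in loader:
    for b in range(data.shape[0]):
        # to_pil_image scales float tensors by 255 and converts them to 8-bit
        image_stats = Stats(tf.to_pil_image(data[b]))
        if statistics is None:
            statistics = image_stats
        else:
            statistics += image_stats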

From there on, the usual Stat attributes can be used:

print(f'mean:{statistics.mean}, std:{statistics.stddev}')
# mean:[199.59, 156.30, 170.59], std:[31.30, 31.28, 35.95]

Note that although this is a neat solution, it is far from the most efficient one, since every tensor is converted back to a PIL image and processed one at a time in Python.
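If speed matters, a common alternative (a different technique from the histogram approach, not part of the solution above) is to accumulate per-channel pixel counts, sums and sums of squares directly on the batches, skipping the round trip through PIL. The sketch below uses NumPy and assumes each batch is a float array shaped (N, C, H, W), e.g. `data.numpy()` inside the loop above; `RunningChannelStats` is a hypothetical name.

```python
import numpy as np


class RunningChannelStats:
    """Accumulates per-channel count, sum and sum of squares over batches."""

    def __init__(self, num_channels):
        self.count = 0
        # float64 accumulators reduce (but do not remove) rounding error
        self.sum = np.zeros(num_channels, dtype=np.float64)
        self.sum_sq = np.zeros(num_channels, dtype=np.float64)

    def update(self, batch):
        # batch: array shaped (N, C, H, W); reduce over every axis except C
        batch = np.asarray(batch, dtype=np.float64)
        axes = (0, 2, 3)
        self.count += batch.shape[0] * batch.shape[2] * batch.shape[3]
        self.sum += batch.sum(axis=axes)
        self.sum_sq += (batch ** 2).sum(axis=axes)

    @property
    def mean(self):
        return self.sum / self.count

    @property
    def stddev(self):
        # population std: sqrt(E[x^2] - E[x]^2), clamped at 0 for safety
        var = self.sum_sq / self.count - self.mean ** 2
        return np.sqrt(np.maximum(var, 0.0))


# Usage with random stand-in batches of 4 RGB 8x8 images
rng = np.random.default_rng(0)
batches = [rng.uniform(0.0, 1.0, size=(4, 3, 8, 8)) for _ in range(3)]
stats = RunningChannelStats(num_channels=3)
for batch in batches:
    stats.update(batch)
print(stats.mean, stats.stddev)
```

Because it works on the tensors' values directly, the results are in whatever range the tensors use (e.g. [0, 1] after ToTensor) rather than 0 to 255, and they skip the 8-bit conversion done by to_pil_image.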
