Estimate the mean and std of a large dataset using Welford's method

Hi,

I’m trying to estimate the per-channel mean and std of a large image dataset using Welford’s algorithm. From what I’ve read, this algorithm is more numerically robust than the method that accumulates the sum and the sum of squares of the values. However, I’m facing some issues. The first one is that I tested the algorithm on a small random dataset (6 images), and the mean and std estimated by Welford’s method are not exactly the same as the mean and std computed over the whole dataset at once. The difference is very small, but I’m not sure whether I’m doing something wrong.
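For reference, the sum / sum-of-squares accumulation I’m comparing against looks more or less like this (per-channel, accumulated in float64; just a rough sketch of what I was doing before):

import torch

def naive_channel_stats(images):
    # accumulate per-channel sum and sum of squares in float64
    n = 0
    s = torch.zeros(3, dtype=torch.float64)
    ss = torch.zeros(3, dtype=torch.float64)
    for img in images:                      # img: [N, 3, H, W]
        x = img.double()
        n += x.shape[0] * x.shape[2] * x.shape[3]
        s += x.sum(dim=[0, 2, 3])
        ss += (x ** 2).sum(dim=[0, 2, 3])
    mean = s / n
    var = (ss - s ** 2 / n) / (n - 1)       # unbiased estimate
    return mean, torch.sqrt(var)

And this is the Welford implementation I wrote, together with a small test against computing the statistics over the whole dataset at once: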

import torch
torch.set_printoptions(precision=10)
class ImgRunningStats:

    def __init__(self):
        # per-channel running statistics, accumulated in float64
        self.n = torch.zeros(1, dtype=torch.long)
        self.old_m = torch.zeros(3, dtype=torch.float64)
        self.new_m = torch.zeros(3, dtype=torch.float64)
        self.old_s = torch.zeros(3, dtype=torch.float64)
        self.new_s = torch.zeros(3, dtype=torch.float64)

    def push(self, x):
        # x: one pixel as a tensor of 3 channel values
        x = x.double()  # accumulate in float64 to limit rounding error
        self.n += 1

        if self.n == 1:
            # clone so the running mean does not alias the input tensor
            self.old_m = x.clone()
            self.new_m = x.clone()
            self.old_s = torch.zeros_like(x)
        else:
            self.new_m = self.old_m + (x - self.old_m) / self.n
            self.new_s = self.old_s + (x - self.old_m) * (x - self.new_m)

            self.old_m = self.new_m
            self.old_s = self.new_s

    def mean(self):
        return self.new_m if self.n else torch.zeros_like(self.new_m)

    def var(self):
        # unbiased sample variance; return a tensor so std() always works
        return self.new_s / (self.n - 1) if self.n > 1 else torch.zeros_like(self.new_s)

    def std(self):
        return torch.sqrt(self.var())
    


rs = ImgRunningStats()

test_img_0 = torch.rand([2,3,250,250])
test_img_1 = torch.rand([2,3,250,250])
test_img_2 = torch.rand([2,3,250,250])

images = [test_img_0, test_img_1, test_img_2]

test_img = torch.cat([test_img_0, test_img_1, test_img_2], dim=0)

print('mean full: ', test_img.mean(dim=[0,2,3]))
print('std full: ', test_img.std(dim=[0,2,3], unbiased=True))


for img in images:
    # move the channel dimension last before flattening; a plain
    # img.view(-1, 3) interleaves values from different channels
    pixels = img.permute(0, 2, 3, 1).reshape(-1, img.shape[1])
    for px in pixels:
        rs.push(px)
       

print(f'Mean Welford: {rs.mean()}, Std Welford: {rs.std()}')

However, the major issue I have is that the algorithm is very slow… I don’t know how to improve its performance. As far as I can tell, pushing one pixel at a time in a Python loop makes this implementation infeasible for a large image dataset.
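The only idea I’ve had so far is to fold a whole chunk of pixels into the running statistics with one tensor operation, using the pairwise combination formula from Chan et al., instead of looping over pixels in Python. This is just an untested sketch that reuses the fields of my class above:

def push_chunk(rs, x):
    # x: [M, 3] chunk of pixels; fold its statistics into rs in one shot
    # using the pairwise combination formula (Chan et al.) -- untested sketch
    x = x.double()
    m = x.shape[0]
    chunk_mean = x.mean(dim=0)
    chunk_m2 = ((x - chunk_mean) ** 2).sum(dim=0)

    n = int(rs.n)  # samples seen so far
    if n == 0:
        rs.new_m = rs.old_m = chunk_mean
        rs.new_s = rs.old_s = chunk_m2
    else:
        total = n + m
        delta = chunk_mean - rs.new_m
        rs.new_m = rs.old_m = rs.new_m + delta * m / total
        rs.new_s = rs.old_s = rs.new_s + chunk_m2 + delta ** 2 * n * m / total
    rs.n += m


rs2 = ImgRunningStats()
for img in images:
    push_chunk(rs2, img.permute(0, 2, 3, 1).reshape(-1, 3))
print(f'Mean chunked: {rs2.mean()}, Std chunked: {rs2.std()}')

I haven’t really verified this beyond a quick look at the formula, so I’d appreciate a sanity check.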

Is it possible to parallelize this algorithm?
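If that combination formula is right, I guess I could also run one ImgRunningStats per worker/shard and merge the partial results at the end, something like this (again, only a sketch of what I have in mind):

def merge(a, b):
    # combine two ImgRunningStats computed on disjoint parts of the dataset
    # (same pairwise combination formula; untested sketch)
    n_a, n_b = int(a.n), int(b.n)
    if n_a == 0:
        return b
    if n_b == 0:
        return a
    out = ImgRunningStats()
    total = n_a + n_b
    delta = b.new_m - a.new_m
    out.n += total
    out.new_m = out.old_m = a.new_m + delta * n_b / total
    out.new_s = out.old_s = a.new_s + b.new_s + delta ** 2 * n_a * n_b / total
    return out

Each process would then build its own ImgRunningStats over its shard, and I would merge them pairwise at the end. Does this approach sound correct?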