Computing the mean and std of a dataset

This is my solution:

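import torch

mean = 0.0
meansq = 0.0
count = 0

# Assumes each batch is an (images, labels) pair, as with torchvision datasets
for images, _ in train_loader:
    images = images.double()       # accumulate in float64 to limit round-off error
    mean += images.sum()           # running sum over all batches
    meansq += (images ** 2).sum()  # running sum of squares
    count += images.numel()        # number of elements in this batch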

total_mean = mean/count
total_var = (meansq/count) - (total_mean**2)
total_std = torch.sqrt(total_var)
print("mean: " + str(total_mean))
print("std: " + str(total_std))
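The loop above pools every channel into a single scalar mean and std. For something like `torchvision.transforms.Normalize` you usually want one value per channel. Below is a minimal sketch of that variant under the same assumptions: the loader yields (images, labels) batches with images shaped (N, C, H, W), and it uses the same identity Var[x] = E[x²] − E[x]². The function name `channel_mean_std` and the random demo data are hypothetical, for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader):
    """Per-channel mean and population std over all images in `loader`.

    Assumes each batch is an (images, labels) pair with images shaped (N, C, H, W).
    Uses Var[x] = E[x^2] - E[x]^2, accumulated in float64.
    """
    channel_sum = None
    channel_sumsq = None
    count = 0  # pixels per channel seen so far (N * H * W summed over batches)
    for images, _ in loader:
        images = images.double()
        if channel_sum is None:
            num_channels = images.shape[1]
            channel_sum = torch.zeros(num_channels, dtype=torch.float64)
            channel_sumsq = torch.zeros(num_channels, dtype=torch.float64)
        # Reduce over batch, height and width; keep the channel dimension
        channel_sum += images.sum(dim=(0, 2, 3))
        channel_sumsq += (images ** 2).sum(dim=(0, 2, 3))
        count += images.numel() // images.shape[1]
    mean = channel_sum / count
    var = channel_sumsq / count - mean ** 2
    std = torch.sqrt(var.clamp(min=0))  # clamp guards against tiny negative round-off
    return mean, std


# Usage with random stand-in data (hypothetical; replace with your own train_loader)
fake_images = torch.rand(100, 3, 8, 8)
fake_labels = torch.zeros(100, dtype=torch.long)
demo_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=16)
mean, std = channel_mean_std(demo_loader)
print("mean:", mean)
print("std:", std)
```

The resulting values can be passed as `transforms.Normalize(mean.tolist(), std.tolist())`. Note this computes the population std (dividing by the count, not count − 1), which matches the original snippet.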