# Extremely ugly, but the only way to standardize a dataset?!

Hi,
so I’ve now spent many hours trying to do the following: standardize the data for use in a DataLoader. I ended up with the code snippet below, which is some of the ugliest code I have written in a long time. I also spent the whole day today hunting down a problem that came from the fact that, for MNIST, `dataset.data` is different from `x, _ = next(iter(DataLoader(dataset, batch_size=len(dataset))))`. Why? What is the point of `dataset.data` at all?
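For context, the discrepancy most likely comes from the transform pipeline: `dataset.data` holds the raw `uint8` pixels, while the `DataLoader` yields samples after `ToTensor` has run. A minimal torch-only sketch of that difference (the random tensor here is a synthetic stand-in for MNIST pixels, not the real dataset):

```python
import torch

# raw pixels as `dataset.data` stores them: uint8 in [0, 255], shape (N, 28, 28)
raw = torch.randint(0, 256, (4, 28, 28), dtype=torch.uint8)

# what ToTensor() does inside the DataLoader pipeline:
# convert to float32 and rescale into [0, 1]
transformed = raw.to(torch.float32) / 255.0

print(raw.dtype, transformed.dtype)  # torch.uint8 torch.float32
print(float(raw.max()), float(transformed.max()))
```

So statistics computed on `dataset.data` are on a 0–255 scale, while statistics computed on batches from the loader are on a 0–1 scale.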

The main question though is: is there any better way to do this?

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

class Standardize:
    def __init__(self, mean, var):
        self.mean = mean
        self.var = var

    def __call__(self, image):
        # make sure mean and var are vectors whose length matches the flattened image
        assert image.shape[0] == self.mean.shape[0] and image.shape[0] == self.var.shape[0]
        return (image - self.mean) / torch.sqrt(self.var)

transforms_1 = [ToTensor(), lambda x: x.view(-1)]

# load the full training set as a single batch to compute per-feature statistics
train_dataset = datasets.MNIST(root="data", train=True, download=True,
                               transform=transforms.Compose(transforms_1))
x, _ = next(iter(DataLoader(train_dataset, batch_size=len(train_dataset))))

mean_vector = x.mean(dim=0)
var_vector = x.var(dim=0)

# If we required drop_last=True (which JAX, for example, needs) for certain batch
# sizes, we would have to change the lines above to account for the dropped last
# batch: after initializing the train_loader, iterate through the loader,
# concatenate all batches, and only then compute the mean and variance

# Replace zero-variance entries with 1 to avoid division by zero
var_vector[var_vector == 0] = 1

transform_2 = transforms.Compose([*transforms_1, Standardize(mean_vector, var_vector)])

batch_size = len(train_dataset)

# Make sure everything works as expected
```

`.data` is an internal attribute of this dataset that stores the raw, untransformed data, which is why it differs from what the `DataLoader` yields. As an alternative to computing the statistics from one full-dataset batch, you could try the `StandardScaler.partial_fit` method from scikit-learn, which accumulates the mean and variance batch by batch.
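That `partial_fit` route can be sketched like this (the list of arrays below is a synthetic stand-in for iterating over a `DataLoader`; `StandardScaler` is from scikit-learn):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# stand-in for batches coming out of a DataLoader:
# 10 batches of 32 flattened 784-dim "images"
rng = np.random.default_rng(0)
batches = [rng.random((32, 784)) for _ in range(10)]

scaler = StandardScaler()
for batch in batches:
    # accumulate running mean/variance one batch at a time,
    # so the full dataset never has to sit in memory at once
    scaler.partial_fit(batch)

# the accumulated statistics match those of the concatenated data
all_data = np.concatenate(batches)
print(np.allclose(scaler.mean_, all_data.mean(axis=0)))  # True
print(np.allclose(scaler.var_, all_data.var(axis=0)))    # True
```

This also sidesteps the `drop_last` issue, since you can feed the scaler whatever batches the loader actually yields.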