Hi,

so I’ve now spent many hours trying to do the following: standardizing the data for use in a DataLoader. I ended up with the code snippet below, which is some of the ugliest code I have written in a long time. I also spent the whole day today tracking down a bug that came from the fact that for MNIST, `dataset.data` is different from `x, _ = next(iter(DataLoader(dataset, batch_size=len(dataset))))`. Why? What is the point of `dataset.data` at all?
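As far as I understand it, `dataset.data` is the raw underlying tensor (for MNIST, uint8 values in [0, 255]) and the `transform` is only applied inside `__getitem__`, which is what the DataLoader calls. A minimal sketch of that behavior, using a hypothetical `ToyDataset` that mimics MNIST's layout instead of downloading anything:

```
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    # mimics MNIST's layout: .data holds raw uint8 images, and the
    # "transform" is only applied inside __getitem__
    def __init__(self):
        self.data = torch.randint(0, 256, (100, 8, 8), dtype=torch.uint8)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # like ToTensor: scale raw bytes to floats in [0, 1]
        return self.data[idx].float() / 255.0, 0

ds = ToyDataset()
x, _ = next(iter(DataLoader(ds, batch_size=len(ds))))
print(ds.data.dtype)  # torch.uint8, raw values in [0, 255]
print(x.dtype)        # torch.float32, transformed values in [0, 1]
```

So the two views of the data differ by exactly whatever the transform pipeline does.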

The main question though is: is there any better way to do this?

```
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor


class Standardize:
    def __init__(self, mean, var):
        self.mean = mean
        self.var = var

    def __call__(self, image):
        # make sure mean and var are vectors matching the feature dimension
        assert image.shape[0] == self.mean.shape[0] and image.shape[0] == self.var.shape[0]
        return (image - self.mean) / torch.sqrt(self.var)


# Load the MNIST dataset and flatten each image to a 784-vector
transforms_1 = [ToTensor(), lambda x: x.view(-1)]
train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.Compose(transforms_1))

# One full-size batch just to compute per-pixel statistics
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=False)
x, _ = next(iter(train_loader))
mean_vector = x.mean(dim=0)
var_vector = x.var(dim=0)
# If we required drop_last=True (which JAX, for example, needs) for certain
# batch sizes, we would have to account for the dropped batch: after
# initializing the train_loader we would iterate through it, concatenate all
# batches, and only then compute the mean and variance.

# Check if the variance has any zero entries and handle them
if torch.any(var_vector == 0):
    var_vector[var_vector == 0] = 1

transform_2 = transforms.Compose([*transforms_1, Standardize(mean_vector, var_vector)])
train_dataset_with_transforms = datasets.MNIST(root='./data', train=True, download=True,
                                               transform=transform_2)
batch_size = len(train_dataset)
train_dataloader = DataLoader(train_dataset_with_transforms, batch_size=batch_size,
                              shuffle=False, drop_last=True)

# Make sure everything works as expected
x, _ = next(iter(train_dataloader))
mean, var = x.mean(), x.var()
print(f"Mean should be zero: {mean:.4f}")
# prints out 0
print(f"Variance should be one: {var:.4f}")
# prints out 0.9145 -- presumably because the zero-variance pixels stay
# constant at 0 after standardizing, dragging the global variance below 1
```
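One alternative I can think of is to skip the first DataLoader entirely and compute the statistics directly on the underlying tensor, as long as the pre-standardization transform is just ToTensor's scaling. A sketch of that idea, with a random `raw` tensor standing in for `train_dataset.data` (with real MNIST you would use `train_dataset.data`, a uint8 tensor of shape `(60000, 28, 28)`):

```
import torch

# stand-in for train_dataset.data
raw = torch.randint(0, 256, (1000, 28, 28), dtype=torch.uint8)

# replicate ToTensor's scaling, then flatten each image
flat = raw.float().div(255.0).view(len(raw), -1)

mean = flat.mean(dim=0)
# clamping the std handles the zero-variance pixels in one step
std = flat.std(dim=0).clamp_min(1e-8)

standardized = (flat - mean) / std
print(standardized.mean().item())  # close to 0
```

Zero-variance pixels come out as exactly 0 here (numerator is 0), so the clamp plays the same role as the `var_vector[var_vector == 0] = 1` line above, without mutating the statistics.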