Normalize data from scratch

Objective: normalize the data by subtracting the mean RGB value of each channel (zero mean).


X_train has shape torch.Size([50000, 3, 32, 32]).

X_train.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).size()

torch.Size([1, 3, 1, 1])
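A minimal sketch of the full centering step, assuming a small random tensor (a hypothetical stand-in, not the real 50000-image X_train) in the same (N, C, H, W) layout. The (1, 3, 1, 1) mean broadcasts against the (N, 3, H, W) data, so each channel's mean is subtracted from every pixel of that channel:

```python
import torch

# Hypothetical stand-in for X_train: random data shaped (N, C, H, W).
# The tensor in the question is (50000, 3, 32, 32); a smaller N keeps this quick.
torch.manual_seed(0)
X_train = torch.rand(100, 3, 32, 32)

# Per-channel mean over batch (0), height (2) and width (3).
# keepdim=True keeps the shape (1, 3, 1, 1) so it broadcasts against X_train.
mean_rgb = X_train.mean(dim=(0, 2, 3), keepdim=True)
print(mean_rgb.size())  # torch.Size([1, 3, 1, 1])

# Broadcasting subtracts each channel's mean from all pixels of that channel.
X_centered = X_train - mean_rgb

# Each channel now has a mean of (numerically) zero.
print(X_centered.mean(dim=(0, 2, 3)))
```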

Just to confirm: we are averaging X_train over dims 0, 2 and 3.
Is the reason we skip dim=1 that we want to keep a separate mean for each of the RGB channels?


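Yes, you are skipping dim1 because you don't want to average over the channel dimension, so you end up with one mean per RGB channel.
Since each group being averaged has the same number of elements, chaining the three mean calls gives the same result as reducing over all three dims at once, so the code can also be written as X_train.mean([0, 2, 3], keepdim=True).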
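A quick check of both points, again assuming a small random tensor as a hypothetical stand-in for X_train: the chained and single-call versions agree up to floating-point rounding, and including dim 1 collapses the three channel means into a single overall mean.

```python
import torch

# Hypothetical stand-in for X_train with shape (N, C, H, W).
torch.manual_seed(0)
X_train = torch.rand(100, 3, 32, 32)

# Chained version from the question: one reduction at a time.
chained = X_train.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

# Single call reducing several dims at once.
joint = X_train.mean([0, 2, 3], keepdim=True)

print(chained.size(), joint.size())   # both torch.Size([1, 3, 1, 1])
print(torch.allclose(chained, joint))  # equal up to float rounding

# Also reducing dim 1 mixes the channels into one overall mean,
# which is not what per-channel (RGB) centering needs.
overall = X_train.mean([0, 1, 2, 3], keepdim=True)
print(overall.size())  # torch.Size([1, 1, 1, 1])
```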
