Objective: normalize the data by subtracting the mean RGB value, so the data is zero-mean.
X_train has shape:
torch.Size([50000, 3, 32, 32])
X_train.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).size()
torch.Size([1, 3, 1, 1])
Just to confirm: we are reducing (averaging) X_train over dims 0, 2 and 3.
Is the reason we skip dim=1 that we want to keep a separate mean for each of the RGB channels?
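A minimal sketch of per-channel mean subtraction, assuming PyTorch and a random stand-in tensor. The batch of 100 random images is illustrative, not the real 50000-image dataset. It checks three things: reducing over dims 0, 2 and 3 with keepdim=True leaves one mean per RGB channel; the chained .mean() calls above match a single mean(dim=(0, 2, 3)); and the (1, 3, 1, 1) result broadcasts back over the full (N, 3, H, W) batch. Reducing over dim 1 as well would collapse everything into a single mean shared by all three channels.

```python
import torch

# Stand-in data in (N, C, H, W) layout. Assumption: random values and a
# smaller N than the real dataset, just to keep the example fast.
torch.manual_seed(0)
X_train = torch.rand(100, 3, 32, 32)

# One mean per channel: reduce over batch (0), height (2) and width (3),
# keep the channel dim (1). keepdim=True keeps the result 4-D.
mean_rgb = X_train.mean(dim=(0, 2, 3), keepdim=True)

# The chained version from the question. It gives the same values because
# every reduction step averages groups of equal size.
chained = X_train.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)

# Reducing over dim 1 too gives a single mean shared by all channels.
global_mean = X_train.mean(dim=(0, 1, 2, 3), keepdim=True)

# Shape (1, 3, 1, 1) broadcasts against (N, 3, H, W): each channel gets
# its own mean subtracted.
X_centered = X_train - mean_rgb

# After centering, each channel's mean is zero (up to float rounding).
channel_means_after = X_centered.mean(dim=(0, 2, 3))

print(mean_rgb.shape)     # torch.Size([1, 3, 1, 1])
print(global_mean.shape)  # torch.Size([1, 1, 1, 1])
```

Without keepdim=True the per-channel result would have shape (3,), which broadcasts against the last dimension (width) rather than the channel dimension. That is why the 4-D (1, 3, 1, 1) shape is useful here.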