Objective: normalize the data by subtracting the mean of each RGB channel (zero mean)

X_train.size()

torch.Size([50000, 3, 32, 32])

X_train.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).size()

torch.Size([1, 3, 1, 1])

Just to confirm: we are averaging X_train over dims 0, 2, and 3 (the batch dim and the two spatial dims), so those dims collapse to size 1?

And the reason we skip dim=1 is that we want to keep a separate mean for each of the RGB channels?
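
To check my understanding, here is a minimal sketch of what I think is going on. It assumes a random tensor as a stand-in for X_train (with a smaller batch than 50000), and the names mean_chained, mean_rgb and X_centered are mine, not from the original code. If I read the docs correctly, the three chained means should match a single mean over dims (0, 2, 3), and the [1, 3, 1, 1] result should broadcast when subtracted.

```python
import torch

torch.manual_seed(0)

# Stand-in for X_train (assumption: random data, smaller batch than 50000)
X_train = torch.rand(500, 3, 32, 32)

# Chained means, one dim at a time, as in the snippet above
mean_chained = (
    X_train.mean(dim=0, keepdim=True)
    .mean(dim=2, keepdim=True)
    .mean(dim=3, keepdim=True)
)

# The same reduction in one call: average over batch and both spatial dims,
# leaving dim 1 (the RGB channels) untouched
mean_rgb = X_train.mean(dim=(0, 2, 3), keepdim=True)

print(mean_rgb.size())  # one mean per channel, shaped for broadcasting

# keepdim=True keeps the size-1 dims, so the subtraction broadcasts
# the per-channel mean across every image and pixel
X_centered = X_train - mean_rgb

# After centering, each channel's mean should be approximately zero
print(X_centered.mean(dim=(0, 2, 3)))
```

If that is right, skipping dim=1 is what keeps the three channel means separate, and keepdim=True is what lets the result line up with X_train for the subtraction.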

Thanks