# Intuition for batch[:,c] in the book

I am reading the book “Deep Learning with PyTorch” and the code is like below:

``````python
print(batch.shape)  # torch.Size([4, 3, 256, 256])
batch /= 255.0

n_channels = batch.shape[1]
for c in range(n_channels):
    mean = torch.mean(batch[:, c])
    std = torch.std(batch[:, c])
    batch[:, c] = (batch[:, c] - mean) / std
``````

I can’t get the intuition of `batch[:,c]`. It returns `torch.Size([4, 256, 256])` where `4` is batch_size, `256` is `height` and `width` of an image. `:` means everything and `c` is looped over 0, 1 and 2. In the first iteration it will be `batch[:,0]`. Why is that returning `torch.Size([4, 256, 256])`?

You index dimension 1 of the `[4, 3, 256, 256]` tensor with an integer, so that dimension is dropped in the resulting view.
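You can check this with a quick sketch on a dummy tensor of the same shape:

``````python
import torch

# Dummy batch with the book's shape: (N, C, H, W)
batch = torch.rand(4, 3, 256, 256)

# Integer indexing in dimension 1 selects one channel and drops that
# dimension, leaving a view of shape (4, 256, 256).
red = batch[:, 0]
print(red.shape)  # torch.Size([4, 256, 256])

# batch[:, 0] is shorthand for batch[:, 0, :, :]
print(torch.equal(red, batch[:, 0, :, :]))  # True
``````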


Yes, you are right: that gives you a tensor of `torch.Size([4, 256, 256])`.

The intuition here is that you want to normalize each image before feeding it to the network. Normalization usually means making the input distribution zero-mean and unit-std. In this piece of code you are normalizing per channel, because each channel (each of the RGB colors) might have a different distribution, so you compute separate statistics for each one. This assumes that values within the same color channel are correlated, which makes sense.
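For what it’s worth, the book’s per-channel loop can also be written without a Python loop by reducing over the batch and spatial dimensions; this is just a sketch equivalent to the loop above:

``````python
import torch

batch = torch.rand(4, 3, 256, 256)

# Per-channel stats: reduce over batch (0), height (2) and width (3),
# keeping dim 1 so the (1, 3, 1, 1) result broadcasts against (N, C, H, W).
mean = batch.mean(dim=(0, 2, 3), keepdim=True)
std = batch.std(dim=(0, 2, 3), keepdim=True)
normalized = (batch - mean) / std
``````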

However, in practice I found that you can also normalize the whole batch of images, aggregating all the channels, without any issue:

``````python
print(batch.shape)  # torch.Size([4, 3, 256, 256])

# Normalize over the whole batch, all channels together
batch_mean = batch.mean()
batch_std = batch.std()
batch = (batch - batch_mean) / batch_std
``````

You can also take a look at the `albumentations` library, specifically `albumentations.Normalize`, which does this for you automatically.


Yeah, so off the top of my head, I think the book mentions that, depending on your problem, you’d take (per-channel) averages over the entire dataset. This is what is usually done for image classifiers (e.g. on ImageNet).
I’m sure you can get good results without splitting by channel, given that the typical ImageNet per-channel means are something like [0.485, 0.456, 0.406] and the stds [0.229, 0.224, 0.225], so the channels are not that far apart.
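Applying those dataset-wide statistics to a batch is a one-liner with broadcasting; a sketch, assuming the inputs are already scaled to [0, 1] (this is the same arithmetic `torchvision.transforms.Normalize` performs):

``````python
import torch

batch = torch.rand(4, 3, 256, 256)  # images already scaled to [0, 1]

# View the stats as (1, 3, 1, 1) so they broadcast over batch, height, width
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
normalized = (batch - mean) / std
``````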
Computing the statistics over the batch is a bit funny, because each normalized image then depends on its neighbors in the batch. I’d probably avoid that just because I’d feel dirty if I did it. On the other hand, batch norm does something similar and is arguably quite successful.

Best regards

Thomas


Yeah, you are completely right! Better not to normalize over the whole batch!

I took a look at my code and the normalization is being done in a subclass of `torch.utils.data.Dataset`, so you are right: it is better to normalize per sample.
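A minimal sketch of what per-sample normalization inside a `Dataset` subclass might look like (the class name and the in-memory `images` tensor are made up for illustration):

``````python
import torch
from torch.utils.data import Dataset


class NormalizedImageDataset(Dataset):
    """Hypothetical dataset that normalizes each image on the fly."""

    def __init__(self, images):
        # images: uint8 tensor of shape (N, 3, H, W), values in [0, 255]
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].float() / 255.0
        # Per-sample, per-channel stats: reduce over H and W only,
        # keeping the channel dim so the (3, 1, 1) result broadcasts.
        mean = img.mean(dim=(1, 2), keepdim=True)
        std = img.std(dim=(1, 2), keepdim=True)
        return (img - mean) / std
``````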