Intuition for batch[:,c] in the book

I am reading the book “Deep Learning with PyTorch” and the code looks like this:

print(batch.shape) # torch.Size([4, 3, 256, 256])
batch /= 255.0

n_channels = batch.shape[1]
for c in range(n_channels):
    print("smth.shape", smth.shape)
    mean = torch.mean(batch[:,c])
    std = torch.std(batch[:,c])
    batch[:, c] = (batch[:, c] - mean) / std

I can’t get the intuition of batch[:, c]. It returns torch.Size([4, 256, 256]), where 4 is the batch size and 256 is the height and width of an image. The : means “everything”, and c loops over 0, 1 and 2, so in the first iteration it is batch[:, 0]. Why does that return torch.Size([4, 256, 256])?

You index into the second dimension (dimension 1) of the [4, 3, 256, 256] tensor with a single integer, so that dimension is dropped in the resulting view.
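You can check this in isolation with a dummy tensor (a minimal sketch):

import torch

batch = torch.rand(4, 3, 256, 256)  # [batch, channel, height, width]

# Indexing a dimension with an integer selects along it and drops it;
# slicing with : (or a range) keeps the dimension.
print(batch[:, 0].shape)    # torch.Size([4, 256, 256])
print(batch[:, 0:1].shape)  # torch.Size([4, 1, 256, 256])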

Yes, you are right: that indexing gives you a tensor of torch.Size([4, 256, 256]).

The intuition here is that you want to normalize the images before feeding them to the network. Normalization usually comprises two operations: subtracting the mean and dividing by the standard deviation, so that the input distribution has zero mean and unit standard deviation. In this piece of code you normalize per channel, because each channel (each RGB color) might have a different distribution, so you want separate statistics per channel. This assumes that values within the same color channel are correlated across images, which makes sense.
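For reference, the per-channel loop from the question can also be written as one vectorized step (a minimal sketch, equivalent to the loop):

# Per-channel statistics computed over batch, height and width at once;
# keepdim=True keeps the result broadcastable against [4, 3, 256, 256].
mean = batch.mean(dim=(0, 2, 3), keepdim=True)  # shape [1, 3, 1, 1]
std = batch.std(dim=(0, 2, 3), keepdim=True)    # shape [1, 3, 1, 1]
batch = (batch - mean) / std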

However, in practice I have found that you can also normalize the whole batch of images, aggregating all the channels, without any issue:

print(batch.shape)  # torch.Size([4, 3, 256, 256])

# Normalize over the whole batch, aggregating all channels
batch_mean = batch.mean()
batch_std = batch.std()
batch = (batch - batch_mean) / batch_std

You can also take a look at the albumentations library, specifically at albumentations.Normalize, which does this for you automatically.
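If I recall the API correctly, usage looks roughly like this (the statistics here are just the common ImageNet values, adjust them for your data):

import albumentations as A
import numpy as np

# Normalize divides by max_pixel_value (255 by default) and then
# applies (img - mean) / std per channel on a HWC numpy image.
transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
normalized = transform(image=image)["image"]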

Yeah, so off the top of my head, I think the book mentions that depending on your problem type, you’d take (per-channel) averages over the entire Dataset. This is what is usually done for image classifiers (ImageNet).
I’m sure you can get good results without splitting by channel, given that the per-channel means are something like [0.485, 0.456, 0.406] and the stds [0.229, 0.224, 0.225], so the channels are not that far apart.
When you normalize over the batch, it is a bit funny because each image’s result depends on its neighbors in the batch; I’d probably avoid that, just because I’d feel dirty doing it. On the other hand, batch norm is arguably quite successful.
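For completeness, applying such dataset-wide per-channel statistics to every sample is what torchvision’s Normalize transform does:

from torchvision import transforms

# ToTensor() scales pixel values into [0, 1]; Normalize then applies
# the per-channel ImageNet statistics to each sample independently.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])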

Best regards

Thomas

Yeah, you are completely right! Better not to normalize across the whole batch!

I took a look at my code, and the normalization is being done in a subclass of torch.utils.data.Dataset, so you are right. It is better to just normalize per sample.
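For anyone reading along, here is a minimal sketch of per-sample normalization inside a Dataset subclass (all names are made up for illustration):

import torch
from torch.utils.data import Dataset

class NormalizedImageDataset(Dataset):
    """Hypothetical Dataset that normalizes each sample on the fly."""

    def __init__(self, images, mean, std):
        self.images = images  # e.g. a uint8 tensor of shape [N, 3, 256, 256]
        self.mean = mean      # per-channel mean, shape [3, 1, 1]
        self.std = std        # per-channel std, shape [3, 1, 1]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx].float() / 255.0
        return (img - self.mean) / self.std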