BatchNorm2d question

Greetings!

I am confused by the BatchNorm2d documentation:
https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html

"The mean and standard-deviation are calculated per-dimension over the mini-batches"

What does per dimension mean?

As far as I know, I want to normalize every feature map, i.e. every channel. But how exactly is a single feature map normalized? They say that \gamma and \beta have the same dimension as the channel number C, so y also has dimension C. Yet, the layer output is supposed to have the same shape as the input… what are y and x in this formula then?

Best,
PiF

Say you have a batch of N RGB (3-channel) images.

Step 1: Normalize the channels with respect to batch values
BatchNorm2d calculates the mean and standard deviation per channel, taken over the batch, height, and width: that is, the mean red, mean green, and mean blue for the whole batch. Thus you have a vector m of means and a vector s of standard deviations, both of shape (3,) (same as channels). You then normalize each channel of all the images using m and s (by broadcasting). Note that this step is the same operation as the “Standard Scaling” preprocessing that is commonly used for data, except that BatchNorm2d computes m and s from the current batch (in training mode) instead of using fixed values, e.g. when you do this:

from torchvision import transforms
transform = transforms.Compose([
#                                 mean                  std
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])  

Step 2: Rescale and re-center channels using learned values \gamma and \beta
For each image in your normalized batch, rescale the values in each channel by multiplying with \gamma, then re-center each channel by adding \beta. \gamma and \beta are learned parameters, both of shape (3,) (same as channels). And you’re done!
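To connect the two steps back to the formula in the documentation, here they are written with explicit indices. The index names n, c, h, w (image in the batch, channel, row, column) are my own notation, not the documentation's; \epsilon is the small constant the documentation adds to the variance for numerical stability.

```latex
\mu_c = \frac{1}{N H W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w}
\qquad
\sigma_c^2 = \frac{1}{N H W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{n,c,h,w} - \mu_c \right)^2

y_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
```

So x and y are the full (N, C, H, W) tensors, while \mu, \sigma^2, \gamma and \beta each hold one value per channel and are broadcast over the other three dimensions.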

Code example
I have provided some illustrative code. It applies batch norm to a batch of 32 “RGB images” of resolution 4 by 4.

import torch 
from torch import nn
from torchvision import transforms

image_batch = torch.randn(32,3,4,4)

gamma = torch.tensor([1,2,3], dtype=torch.float32)
beta = torch.tensor([34,35,69], dtype=torch.float32)

def bn(batch: torch.Tensor):
    '''
    Homemade batch norm
    '''
    # I am doing all this [:, None, None] stuff to add singleton dimensions so I can broadcast things
    batch_mean = batch.mean(dim=(0,2,3))[:, None, None]
    batch_std = batch.std(dim=(0,2,3), unbiased=False)[:, None, None]
    gamma_ = gamma[:, None, None]  
    beta_ = beta[:, None, None]
    return ((batch - batch_mean) / (batch_std)) * gamma_ + beta_

def bn_using_transform(batch: torch.Tensor):
    '''
    Batch norm using transforms.Normalize
    '''
    batch_mean = batch.mean(dim=(0,2,3))
    batch_std = batch.std(dim=(0,2,3), unbiased=False)
    gamma_ = gamma[:, None, None]  
    beta_ = beta[:, None, None]
    transform = transforms.Compose([transforms.Normalize(batch_mean, batch_std)])
    return transform(batch) * gamma_ + beta_

# Batch norm using nn.BatchNorm2d
bn_module = nn.BatchNorm2d(3) 
bn_module.weight = nn.Parameter(gamma) # Set batch norm parameters for illustration
bn_module.bias = nn.Parameter(beta)    

out = bn(image_batch)
out_transform = bn_using_transform(image_batch)
out_module = bn_module(image_batch)

# Relax rtol to accommodate floating-point differences from a different order of operations,
# and the fact that nn.BatchNorm2d adds eps (default 1e-5) to the variance while the homemade versions do not
print(torch.allclose(out, out_transform, rtol=0.0001)) 
# True
print(torch.allclose(out, out_module, rtol=0.0001)) 
# True

print(out[0])
#tensor([[34.6894, 33.5971, 33.4271, 33.4252],
#        [34.0073, 34.7381, 33.6011, 33.7023],
#        [34.5690, 34.2684, 34.0345, 35.0332],
#        [34.4776, 35.2361, 33.2691, 33.3009]])
print(out_module[0])
#tensor([[34.6894, 33.5971, 33.4271, 33.4252],
#        [34.0073, 34.7381, 33.6011, 33.7023],
#        [34.5690, 34.2684, 34.0345, 35.0332],
#        [34.4776, 35.2361, 33.2691, 33.3009]], grad_fn=<SelectBackward>)
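The homemade bn above leaves out the eps term, which is why the comparison needs a relaxed tolerance. Here is a minimal sketch of a variant that reads eps from the module itself and adds it to the variance before taking the square root. The function name bn_with_eps is made up for this example, and I left gamma and beta at the module defaults (ones and zeros) to keep it short.

```python
import torch
from torch import nn

torch.manual_seed(0)
image_batch = torch.randn(32, 3, 4, 4)  # (N, C, H, W)

bn_module = nn.BatchNorm2d(3)  # training mode by default, so batch statistics are used
eps = bn_module.eps            # 1e-5 unless changed in the constructor


def bn_with_eps(batch, gamma, beta, eps):
    # Hypothetical helper: homemade batch norm including the eps term.
    # keepdim=True keeps the statistics at shape (1, C, 1, 1) for broadcasting.
    mean = batch.mean(dim=(0, 2, 3), keepdim=True)
    var = batch.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    normalized = (batch - mean) / torch.sqrt(var + eps)
    return normalized * gamma.view(1, -1, 1, 1) + beta.view(1, -1, 1, 1)


with torch.no_grad():
    out_manual = bn_with_eps(image_batch, bn_module.weight, bn_module.bias, eps)
    out_module = bn_module(image_batch)

# The two results should now agree up to ordinary float32 rounding.
print(torch.allclose(out_manual, out_module, atol=1e-5))
```

Note that this only matches in training mode. In eval mode the module normalizes with its running statistics instead of the statistics of the current batch.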

Thank you Napam, that was insightful!

So for each input channel, we take the average over height and width and then also average over batch, so that we get a 3-element vector that basically holds the average red, the average green, and the average blue activation.

So, sticking to the documentation, the x can be seen as a 3-channel (let’s stick to RGB) image, the E[ ] and Var[ ] are taken with respect to height, width, and batch, and \gamma and \beta are applied element-wise to the 3 channels?

Yes exactly

Yes, I think you got it. \gamma[0] will be multiplied with channel 0 (red), \gamma[1] will be multiplied with channel 1 (green), etc. Same idea with \beta, except you do addition.
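If you want to convince yourself, here is a small sketch that checks the shapes and the per-channel effect of \gamma and \beta. The values are the same illustrative ones as in my earlier code, nothing special about them.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(32, 3, 4, 4)  # (N, C, H, W)

gamma = torch.tensor([1.0, 2.0, 3.0])
beta = torch.tensor([34.0, 35.0, 69.0])

bn_module = nn.BatchNorm2d(3)
with torch.no_grad():
    bn_module.weight.copy_(gamma)  # one scale per channel
    bn_module.bias.copy_(beta)     # one shift per channel
    y = bn_module(x)

print(bn_module.weight.shape, bn_module.bias.shape)  # parameters have one entry per channel
print(y.shape == x.shape)                            # output keeps the input shape

# Per-channel statistics of the output, again taken over batch, height and width
print(y.mean(dim=(0, 2, 3)))                 # close to beta
print(y.std(dim=(0, 2, 3), unbiased=False))  # close to gamma
```

After normalization each channel has mean 0 and standard deviation about 1, so scaling by \gamma[c] and shifting by \beta[c] leaves channel c with standard deviation about \gamma[c] and mean \beta[c].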

Thank you, I got it now!