Say you have a batch of **N** RGB (3-channel) images.

**Step 1: Normalize the channels with respect to batch values**

`BatchNorm2d` calculates the mean and standard deviation of each channel across the whole batch, that is, the mean red, mean green, and mean blue for the batch. This gives you a vector **m** of means and a vector **s** of standard deviations, both of shape `3` (same as channels). You then normalize each channel for all the images using **m** and **s** (by broadcasting). Note that this is the same operation as the "Standard Scaling" preprocessing that is commonly used for data, except that the statistics come from the current batch instead of being fixed ahead of time, e.g. when you do this:

```
from torchvision import transforms

transform = transforms.Compose([
    #                    mean                   std
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
```
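To make Step 1 concrete, here is a minimal sketch of the per-channel statistics on their own. The batch size, image size, and the variable names (`images`, `m`, `s`) are my own choices for illustration, not anything `BatchNorm2d` exposes, and the sketch leaves out the small epsilon the real module adds.

```python
import torch

torch.manual_seed(0)

# Hypothetical batch: 8 RGB images of 5x5, shifted and scaled so the stats are not trivial
images = torch.randn(8, 3, 5, 5) * 4.0 + 10.0

# Reduce over batch (0), height (2) and width (3): one value per channel remains
m = images.mean(dim=(0, 2, 3))
s = images.std(dim=(0, 2, 3), unbiased=False)
print(m.shape, s.shape)

# Reshape to (3, 1, 1) so the vectors broadcast over (N, C, H, W)
normalized = (images - m[:, None, None]) / s[:, None, None]

# After normalization each channel should have mean ~0 and std ~1
channel_means = normalized.mean(dim=(0, 2, 3))
channel_stds = normalized.std(dim=(0, 2, 3), unbiased=False)
print(channel_means, channel_stds)
```

The reduction over dimensions `(0, 2, 3)` is what makes the statistics "per channel": every pixel of every image in the batch contributes to the same three numbers.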

**Step 2: Rescale and re-center channels using learned values \gamma and \beta**

For each image in your normalized batch, rescale the values in each channel using **\gamma**, then recenter each channel using **\beta**. **\gamma** and **\beta** are learned parameters, both of shape `3` (same as channels), and you're done!
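As a rough sketch of Step 2 in isolation: `nn.BatchNorm2d` keeps **\gamma** in its `weight` attribute and **\beta** in its `bias` attribute. The `gamma` and `beta` values below are made up to stand in for learned parameters, and the batch is a hypothetical one normalized without epsilon.

```python
import torch
from torch import nn

torch.manual_seed(0)

# gamma is stored as `weight` and beta as `bias`, one value per channel
# (in current PyTorch they start out as ones and zeros respectively)
layer = nn.BatchNorm2d(3)
print(layer.weight.shape, layer.bias.shape)

# Hypothetical batch, normalized per channel as in Step 1
x = torch.randn(8, 3, 5, 5)
m = x.mean(dim=(0, 2, 3), keepdim=True)
s = x.std(dim=(0, 2, 3), unbiased=False, keepdim=True)
x_hat = (x - m) / s

# Made-up values standing in for learned parameters
gamma = torch.tensor([1.0, 2.0, 3.0])
beta = torch.tensor([0.5, -1.0, 2.0])
y = x_hat * gamma[:, None, None] + beta[:, None, None]

# Each channel now has std ~gamma and mean ~beta
out_mean = y.mean(dim=(0, 2, 3))
out_std = y.std(dim=(0, 2, 3), unbiased=False)
print(out_mean, out_std)
```

In other words, after Step 2 the network is free to pick whatever per-channel mean and spread works best, rather than being stuck with 0 and 1.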

**Code example**

I have provided some illustrative code. It applies batch norm to a batch of 32 "RGB images" of resolution 4 by 4.

```
import torch
from torch import nn
from torchvision import transforms
image_batch = torch.randn(32,3,4,4)
gamma = torch.tensor([1,2,3], dtype=torch.float32)
beta = torch.tensor([34,35,69], dtype=torch.float32)

def bn(batch: torch.Tensor):
    '''
    Homemade batch norm
    '''
    # The [:, None, None] indexing adds singleton dimensions so the
    # per-channel vectors broadcast over (N, C, H, W)
    batch_mean = batch.mean(dim=(0,2,3))[:, None, None]
    batch_std = batch.std(dim=(0,2,3), unbiased=False)[:, None, None]
    gamma_ = gamma[:, None, None]
    beta_ = beta[:, None, None]
    return ((batch - batch_mean) / (batch_std)) * gamma_ + beta_

def bn_using_transform(batch: torch.Tensor):
    '''
    Batch norm using transforms.Normalize
    '''
    batch_mean = batch.mean(dim=(0,2,3))
    batch_std = batch.std(dim=(0,2,3), unbiased=False)
    gamma_ = gamma[:, None, None]
    beta_ = beta[:, None, None]
    transform = transforms.Compose([transforms.Normalize(batch_mean, batch_std)])
    return transform(batch) * gamma_ + beta_

# Batch norm using nn.BatchNorm2d
# (a freshly created module is in training mode, so it uses batch statistics)
bn_module = nn.BatchNorm2d(3)
bn_module.weight = nn.Parameter(gamma) # Set batch norm parameters for illustration
bn_module.bias = nn.Parameter(beta)
out = bn(image_batch)
out_transform = bn_using_transform(image_batch)
out_module = bn_module(image_batch)
# Relax the rtol criterion to accommodate small floating-point differences from out-of-order execution,
# and because nn.BatchNorm2d adds a small epsilon (eps=1e-5 by default) to the variance, which bn() omits
print(torch.allclose(out, out_transform, rtol=0.0001))
# True
print(torch.allclose(out, out_module, rtol=0.0001))
# True
print(out[0])
#tensor([[34.6894, 33.5971, 33.4271, 33.4252],
# [34.0073, 34.7381, 33.6011, 33.7023],
# [34.5690, 34.2684, 34.0345, 35.0332],
# [34.4776, 35.2361, 33.2691, 33.3009]])
print(out_module[0])
#tensor([[34.6894, 33.5971, 33.4271, 33.4252],
# [34.0073, 34.7381, 33.6011, 33.7023],
# [34.5690, 34.2684, 34.0345, 35.0332],
# [34.4776, 35.2361, 33.2691, 33.3009]], grad_fn=<SelectBackward>)
```
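On the epsilon question: the module exposes it as the `eps` attribute, and its default is `1e-5`. Here is a small sketch, under the assumption that the module is in training mode with its default `weight` and `bias`, showing that once epsilon is added to the (biased) variance the homemade version tracks the module closely. The names `manual`, `from_module`, and `max_diff` are just mine.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(32, 3, 4, 4)

layer = nn.BatchNorm2d(3)  # training mode by default, so batch statistics are used
eps = layer.eps            # 1e-5 unless you pass a different value

# Homemade version: biased variance plus eps under the square root
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps)

# Default weight is ones and default bias is zeros, so no rescaling is needed here
with torch.no_grad():
    from_module = layer(x)

max_diff = (manual - from_module).abs().max().item()
print(eps, max_diff)
```

With epsilon included, what is left is ordinary float32 rounding, so the remaining difference should be tiny.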