Say you have a batch of N RGB (3-channel) images, i.e. a tensor of shape (N, 3, H, W).
Step 1: Normalize the channels with respect to batch values
BatchNorm2d will calculate the mean and standard deviation with respect to each channel, i.e. the mean red, mean green and mean blue value across the whole batch. This gives you a vector m of means and a vector s of standard deviations, both of shape 3 (same as channels). You then normalize each channel of every image using m and s (by broadcasting). Note that this step is equivalent to the “Standard Scaling” preprocessing that is commonly applied to data, e.g. when you do this:
from torchvision import transforms

transform = transforms.Compose([
    #                     mean                 std
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
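To make the broadcasting concrete, here is Step 1 done by hand (the batch shape is an arbitrary choice for illustration):

```python
import torch

batch = torch.randn(8, 3, 4, 4)  # (N, C, H, W)

# Per-channel statistics over the batch, height and width dims: shape (3,)
m = batch.mean(dim=(0, 2, 3))
s = batch.std(dim=(0, 2, 3), unbiased=False)

# Reshape to (1, 3, 1, 1) so each channel is normalized independently
normalized = (batch - m[None, :, None, None]) / s[None, :, None, None]

print(normalized.mean(dim=(0, 2, 3)))  # approximately [0, 0, 0]
print(normalized.std(dim=(0, 2, 3), unbiased=False))  # approximately [1, 1, 1]
```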
Step 2: Rescale and re-center channels using learned values \gamma and \beta
For each image in your normalized batch, rescale the values in each channel by \gamma, then re-center each channel by adding \beta. \gamma and \beta are learned parameters, both of shape 3 (same as channels), and you’re done!
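In code, Step 2 is just a broadcasted multiply and add on a batch that has already been normalized (the \gamma and \beta values below are made up for illustration):

```python
import torch

# Assume `normalized` is the output of Step 1: shape (N, 3, H, W),
# zero mean and unit std per channel
normalized = torch.randn(8, 3, 4, 4)

gamma = torch.tensor([1.0, 2.0, 3.0])   # learned scale, one value per channel
beta = torch.tensor([0.5, -0.5, 0.0])   # learned shift, one value per channel

# Broadcast (3,) -> (1, 3, 1, 1): scale each channel, then shift it
out = normalized * gamma[None, :, None, None] + beta[None, :, None, None]
```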
Code example
I have provided some illustrative code below. It performs batch norm on a batch of 32 “RGB images of resolution 4 by 4”.
import torch
from torch import nn
from torchvision import transforms

image_batch = torch.randn(32, 3, 4, 4)
gamma = torch.tensor([1, 2, 3], dtype=torch.float32)
beta = torch.tensor([34, 35, 69], dtype=torch.float32)
def bn(batch: torch.Tensor):
    '''
    Homemade batch norm
    '''
    # The [:, None, None] indexing adds singleton dimensions so the
    # per-channel statistics broadcast against the (N, C, H, W) batch
    batch_mean = batch.mean(dim=(0, 2, 3))[:, None, None]
    batch_std = batch.std(dim=(0, 2, 3), unbiased=False)[:, None, None]
    gamma_ = gamma[:, None, None]
    beta_ = beta[:, None, None]
    return ((batch - batch_mean) / batch_std) * gamma_ + beta_
def bn_using_transform(batch: torch.Tensor):
    '''
    Batch norm using transforms.Normalize
    '''
    batch_mean = batch.mean(dim=(0, 2, 3))
    batch_std = batch.std(dim=(0, 2, 3), unbiased=False)
    gamma_ = gamma[:, None, None]
    beta_ = beta[:, None, None]
    transform = transforms.Compose([transforms.Normalize(batch_mean, batch_std)])
    return transform(batch) * gamma_ + beta_
'''
Batch norm using nn.BatchNorm2d
'''
bn_module = nn.BatchNorm2d(3)
bn_module.weight = nn.Parameter(gamma) # Set batch norm parameters for illustration
bn_module.bias = nn.Parameter(beta)
out = bn(image_batch)
out_transform = bn_using_transform(image_batch)
out_module = bn_module(image_batch)
# Relax the rtol criterion to accommodate small numerical differences from
# different floating-point execution orders, and the eps (1e-5 by default)
# that nn.BatchNorm2d adds to the variance, which the homemade versions omit
print(torch.allclose(out, out_transform, rtol=0.0001))
# True
print(torch.allclose(out, out_module, rtol=0.0001))
# True
print(out[0])
#tensor([[34.6894, 33.5971, 33.4271, 33.4252],
# [34.0073, 34.7381, 33.6011, 33.7023],
# [34.5690, 34.2684, 34.0345, 35.0332],
# [34.4776, 35.2361, 33.2691, 33.3009]])
print(out_module[0])
#tensor([[34.6894, 33.5971, 33.4271, 33.4252],
# [34.0073, 34.7381, 33.6011, 33.7023],
# [34.5690, 34.2684, 34.0345, 35.0332],
# [34.4776, 35.2361, 33.2691, 33.3009]], grad_fn=<SelectBackward>)
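For completeness: the relaxed tolerance above is only needed because the homemade versions omit eps. nn.BatchNorm2d adds eps (1e-5 by default) to the variance before taking the square root, so folding the same eps in reproduces the module's output much more closely. A self-contained sketch (tensor values are made up for illustration):

```python
import torch
from torch import nn

x = torch.randn(32, 3, 4, 4)
g = torch.tensor([1.0, 2.0, 3.0])     # gamma (scale)
b = torch.tensor([34.0, 35.0, 69.0])  # beta (shift)

def bn_with_eps(batch: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Normalize with sqrt(var + eps), matching nn.BatchNorm2d's formula,
    # then apply the per-channel affine parameters by broadcasting
    mean = batch.mean(dim=(0, 2, 3), keepdim=True)
    var = batch.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (batch - mean) / torch.sqrt(var + eps)
    return x_hat * g[None, :, None, None] + b[None, :, None, None]

module = nn.BatchNorm2d(3)  # default eps=1e-5
module.weight = nn.Parameter(g.clone())
module.bias = nn.Parameter(b.clone())

print(torch.allclose(bn_with_eps(x), module(x), atol=1e-6))  # True
```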