Hi all, I'm pretty sure I'm not understanding BatchNorm2d. To test my understanding, I ran the following experiment. I take a random tensor, subtract the mean and divide by the std taken across the batch dimension (for each feature and pixel), spin up a BatchNorm2d, and initialize its bias to 0 and its weights to 1. I would expect the batchnorm to act as the identity on anything that has zero mean and unit variance across the batch dimension, but I find this is not the case.
import torch
x = torch.rand(6,64,224,224)
mu = x.mean(dim=0)
sigma = x.std(dim=0)
(Notice that mu.size() and sigma.size() are both (64, 224, 224), so we have computed the mean and std across the batch for each pixel of each channel.)
x_ = (x-mu)/(sigma+1e-5)
Okay, so by my understanding of batchnorm we expect bn(x_) == x_, but this is not the case. Can someone please explain what is going wrong?
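A minimal sketch that reproduces the first attempt and checks what BatchNorm2d actually does to x_. The smaller tensor size and the seed are arbitrary choices (only the batch size of 6 is kept from the question), and the explanation assumes BatchNorm2d's documented behaviour of normalizing with the biased variance (dividing by N, not N - 1). Because x.std uses the unbiased estimator over only 6 samples, each normalized pixel ends up with biased variance of about 5/6, so the output looks like a uniform rescaling by roughly sqrt(6/5):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Smaller than the question's (6, 64, 224, 224), but with the same batch size of 6.
x = torch.rand(6, 8, 32, 32)

# First attempt: per-pixel statistics taken over the batch dimension only.
mu = x.mean(dim=0)
sigma = x.std(dim=0)  # unbiased estimator: divides by N - 1 = 5
x_ = (x - mu) / (sigma + 1e-5)

bn = nn.BatchNorm2d(8)  # a fresh module is in training mode, so it uses batch statistics
with torch.no_grad():
    bn.weight.fill_(1)
    bn.bias.fill_(0)
    out = bn(x_)

# Not the identity...
print((out - x_).abs().max())
# ...but close to rescaling by sqrt(N / (N - 1)) with N = 6.
print((out - x_ * (6 / 5) ** 0.5).abs().max())
```

The gap comes from the difference between the unbiased std that x.std computes and the biased variance BatchNorm2d normalizes with; over only 6 samples that factor is far from 1.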
I first thought that maybe I should not be looking at per-pixel means/stds, and instead computed:
mu = x.mean(dim=-1).mean(dim=-1).mean(dim=0).reshape(1,64,1,1)
and similarly for sigma, so that each filter receives a single mean/std. But again this did not yield the expected result. Can somebody please comment on what I might be missing? Thank you.
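A small sketch of why chaining reductions works for the mean but not for the standard deviation. It assumes "similarly" means either replacing each .mean with .std or averaging per-row stds (both are shown, since the exact code isn't given). A different offset is added to each image (an arbitrary, illustrative choice) so that the groups actually differ, which is when the discrepancy is most visible:

```python
import torch

torch.manual_seed(0)
# Illustrative data: each image in the batch gets a different constant offset.
x = torch.rand(6, 8, 16, 16) + torch.arange(6.0).view(6, 1, 1, 1)

# Pooled per-channel statistics over (batch, height, width).
pooled_mean = x.mean(dim=(0, 2, 3))
pooled_std = x.std(dim=(0, 2, 3))

# Chained means: every group has the same size, so the mean of means equals the pooled mean.
chained_mean = x.mean(dim=-1).mean(dim=-1).mean(dim=0)

# Two ways one might chain the std; neither reproduces the pooled std.
std_of_stds = x.std(dim=-1).std(dim=-1).std(dim=0)    # spread of the spreads
mean_of_stds = x.std(dim=-1).mean(dim=-1).mean(dim=0)  # ignores between-row/image variation

print(torch.allclose(chained_mean, pooled_mean))
print((std_of_stds - pooled_std).abs().max())
print((mean_of_stds - pooled_std).abs().max())
```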
BatchNorm2d computes the mean and standard deviation per channel – that is over the batch, height, and width.
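To make "per channel" concrete, here is a sketch that recomputes BatchNorm2d's training-mode output by hand. It assumes the behaviour described in the PyTorch documentation: the variance is the biased estimator, and eps (default 1e-5) is added to the variance inside the square root rather than to the std. Tensor size and seed are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(6, 8, 16, 16)

bn = nn.BatchNorm2d(8)  # fresh module: training mode, so batch statistics are used

with torch.no_grad():
    out = bn(x)

    # Per-channel statistics over (batch, height, width), broadcastable to (N, C, H, W).
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # biased variance
    manual = (x - mean) / torch.sqrt(var + bn.eps)
    # Affine step with the module's own weight (gamma) and bias (beta).
    manual = manual * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

print((out - manual).abs().max())
```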
import torch
import torch.nn as nn
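bn = nn.BatchNorm2d(64)  # a new module is in training mode, so it normalizes with batch statistics
bn.weight.data.fill_(1)  # scale (gamma) = 1
bn.bias.data.fill_(0)    # shift (beta) = 0
x = torch.rand(6,64,224,224)
# per-channel statistics over batch, height and width; keepdim=True gives shape (1,64,1,1)
mu = x.mean(dim=(0,2,3), keepdim=True)
sigma = x.std(dim=(0,2,3), keepdim=True)
x_ = (x-mu)/(sigma+1e-5)
print((bn(x_) - x_).abs().max()) # bn(x_) and x_ are approximately equal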
Standard deviation over multiple dimensions is a new feature. I think it's currently only in the preview ("nightly") builds (you can get them from https://pytorch.org/).
Alternatively, for other PyTorch versions you can write:
import torch
import torch.nn as nn
bn = nn.BatchNorm2d(64)
bn.weight.data.fill_(1)
bn.bias.data.fill_(0)
x = torch.rand(6,64,224,224)
tmp = x.permute(1,0,2,3).reshape(64, -1)  # shape (64, 6*224*224): one row per channel
mu = tmp.mean(dim=1).reshape(1,64,1,1)
sigma = tmp.std(dim=1).reshape(1,64,1,1)
x_ = (x-mu)/(sigma+1e-5)
print((bn(x_) - x_).abs().max()) # bn(x_) and x_ are approximately equal
I.e., you permute x so that the batch, height, and width dimensions are adjacent, flatten them into a single dimension, and then take the mean and std over that dimension.
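A quick check of the permute-and-reshape trick against the multi-dimension reductions, run on a PyTorch version that supports both so they can be compared. Sizes, seed, and the per-channel offset (added so that channels are clearly distinguishable) are arbitrary illustrative choices. It also shows that reshaping without the permute mixes values from different channels into each row:

```python
import torch

torch.manual_seed(0)
# Illustrative data: each channel gets a different constant offset.
x = torch.rand(6, 8, 16, 16) + torch.arange(8.0).view(1, 8, 1, 1)
C = x.size(1)

# Channel first, then flatten batch/height/width into one dimension per channel.
tmp = x.permute(1, 0, 2, 3).reshape(C, -1)
print(torch.allclose(tmp.mean(dim=1), x.mean(dim=(0, 2, 3))))
print(torch.allclose(tmp.std(dim=1), x.std(dim=(0, 2, 3))))

# Without the permute, each row holds consecutive values that span several channels.
wrong = x.reshape(C, -1)
print(torch.allclose(wrong.mean(dim=1), x.mean(dim=(0, 2, 3))))
```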