Correct function for BatchNorm2D

I'm trying to reproduce the exact computation performed by BatchNorm2d.
Suppose my model is the following:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.bn1 = nn.BatchNorm2d(32)
        # assumes 28x28 inputs: the padded conv keeps 28x28, pooling halves it to 14x14
        self.fc1 = nn.Linear(32 * 14 * 14, 10)

    def forward(self, x):
        x = self.bn1(self.pool(self.conv1(x)))
        x = F.relu(x)
        x = x.view(-1, 32 * 14 * 14)
        x = self.fc1(x)
        return x

tem1 = net.pool(net.conv1(x))

Could you please tell me which expression is equivalent to the following?

bn = net.bn1(tem1)

The closest expression I have come up with is

net.bn1.weight.view([32,1,1])*((tem1 - net.bn1.running_mean.view([32,1,1]))/((net.bn1.running_var.view([32,1,1]) + net.bn1.eps).sqrt())) + net.bn1.bias.view([32,1,1])

But its output still differs slightly from net.bn1(tem1).
Many thanks.

What do you mean by equivalent? As far as I understand, you want to simplify

tem1 = net.pool(net.conv1(x))
bn = net.bn1(tem1)

which is just the combination conv -> pool -> batchnorm.

I want to understand the exact implementation inside BatchNorm2d, so I'm trying to figure out the exact equation computed by net.bn1(tem1). By "equivalent" I mean that both expressions should produce the same output.

I think running_mean and running_var are updated during training, but they are only used for normalization during testing. During training, the batch statistics (mean and variance) are used directly for normalization. @ptrblck, is that right?
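Written out per channel $c$, this is a sketch of the two cases. The symbols are my own notation, not taken from the docs: $\gamma$ = bn1.weight, $\beta$ = bn1.bias, $\hat\mu$ = bn1.running_mean, $\hat\sigma^2$ = bn1.running_var, $\epsilon$ = bn1.eps, and $N, H, W$ are the batch size and spatial size. The eval line is the expression posted in the question, which would explain why it only matches net.bn1(tem1) after net.eval():

```latex
\begin{aligned}
\text{train:}\quad & y_{n,c,h,w} = \gamma_c \,\frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c,
\qquad \mu_c = \frac{1}{NHW}\sum_{n,h,w} x_{n,c,h,w},
\quad \sigma_c^2 = \frac{1}{NHW}\sum_{n,h,w} \left(x_{n,c,h,w} - \mu_c\right)^2 \\
\text{eval:}\quad & y_{n,c,h,w} = \gamma_c \,\frac{x_{n,c,h,w} - \hat{\mu}_c}{\sqrt{\hat{\sigma}_c^2 + \epsilon}} + \beta_c
\end{aligned}
```

Note that the training-mode variance $\sigma_c^2$ is the biased one (division by $NHW$, not $NHW - 1$).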

As far as BatchNorm2d computation is concerned, the following works:

bn_in = ...  # input tensor for batch norm (N x C x H x W), e.g. tem1
C = bn_in.shape[1]  # number of channels

# flatten the spatial dimensions: N x C x (H*W)
bn_in_flat = bn_in.view(bn_in.shape[0], C, -1)

# we need the mean and std per channel across the whole batch, so move the channel dimension first and flatten the rest
# (reshape rather than view, because permute returns a non-contiguous tensor)
bn_in_flat = bn_in_flat.permute(1, 0, 2).reshape(C, -1)

# calculate the per-channel mean
bn_in_mean = bn_in_flat.mean(dim=-1).view(1, C, 1, 1)

# calculate the biased std (unbiased=False), which is what batch norm uses for normalization
bn_in_std = bn_in_flat.std(dim=-1, unbiased=False).view(1, C, 1, 1)

# batch norm 2D computation in training mode (batch statistics, not running estimates)
bn_out = (net.bn1.weight.view(1, C, 1, 1) *
          (bn_in - bn_in_mean) / (bn_in_std**2 + net.bn1.eps).sqrt() +
          net.bn1.bias.view(1, C, 1, 1))
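To check the training-mode formula without the Net model, here is a small self-contained sketch. The random input, its shape (that of tem1 for 28x28 images) and the random weight/bias values are assumptions for illustration. It compares the manual computation with nn.BatchNorm2d in train mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# hypothetical input with the shape of tem1: batch of 8, 32 channels, 14x14
x = torch.randn(8, 32, 14, 14)
bn = nn.BatchNorm2d(32)
bn.train()  # training mode: normalize with the batch statistics

with torch.no_grad():
    # non-trivial affine parameters so that weight and bias are actually tested
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

    # per-channel statistics over the batch and spatial dimensions (N, H, W)
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)  # biased variance

    manual = (bn.weight.view(1, -1, 1, 1) * (x - mean) / torch.sqrt(var + bn.eps)
              + bn.bias.view(1, -1, 1, 1))
    reference = bn(x)

print(torch.allclose(manual, reference, atol=1e-5))  # expected: True
```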

Yes, you are correct. The running estimates are used during eval. If track_running_stats is set to False, the batch statistics will be used during training and eval.
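A minimal sketch of the eval case, assuming a freshly created layer whose running estimates are first moved away from their defaults with a few training-mode passes on random data. After bn.eval(), the expression from the original question reproduces the layer output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(32)

# a few training-mode forward passes on hypothetical random data,
# so that running_mean / running_var differ from their defaults
bn.train()
for _ in range(5):
    bn(torch.randn(8, 32, 14, 14) * 2 + 1)

bn.eval()  # eval mode: normalize with running_mean / running_var
x = torch.randn(8, 32, 14, 14)
with torch.no_grad():
    # same expression as the one posted in the question
    manual = bn.weight.view(32, 1, 1) * (
        (x - bn.running_mean.view(32, 1, 1))
        / (bn.running_var.view(32, 1, 1) + bn.eps).sqrt()
    ) + bn.bias.view(32, 1, 1)
    reference = bn(x)

print(torch.allclose(manual, reference, atol=1e-5))  # expected: True
```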

A small correction/clarification is needed. track_running_stats is set to False?

Yes, thanks for catching this typo! :slight_smile:
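A small sketch of the track_running_stats=False case (the random input is an assumption for illustration): no running estimates are stored at all, and even in eval mode each channel of the output is normalized with the statistics of the current batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(32, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # None None: no running estimates are kept

bn.eval()  # even in eval mode, the batch statistics are used
x = torch.randn(8, 32, 14, 14) * 3 + 2  # hypothetical unnormalized input
with torch.no_grad():
    out = bn(x)

# with the default weight=1 and bias=0, every output channel is normalized:
# per-channel mean close to 0 and (biased) variance close to 1
print(out.mean(dim=(0, 2, 3)).abs().max())
print(out.var(dim=(0, 2, 3), unbiased=False).mean())
```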

Hi all, thanks a lot for your help. I'm able to generate the same output now.

I am wondering why running_mean and running_var are not used during training. Intuitively, that would make the model more stable, and the convergence would be faster.

These running stats are updated during training and are initialized to the default values of 0s for running_mean and 1s for running_var. Using them during training without updating them wouldn't normalize the data.
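To make the point about the defaults concrete, here is a small sketch with a fresh layer and a hypothetical, clearly unnormalized input. With running_mean = 0, running_var = 1, weight = 1 and bias = 0, normalizing with the running stats just divides by sqrt(1 + eps), so the data stays essentially unchanged:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(32)
print(bn.running_mean[:3], bn.running_var[:3])  # zeros and ones by default

bn.eval()  # use the (still default) running estimates
x = torch.randn(8, 32, 14, 14) * 5 + 10  # hypothetical unnormalized input
with torch.no_grad():
    out = bn(x)

# the layer computes x / sqrt(1 + eps): mean ~10 and std ~5 are kept
print(torch.allclose(out, x / (1 + bn.eps) ** 0.5))  # expected: True
```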

Also, even if you update the stats during training and use them for normalization at the same time, I would assume your model might perform quite badly, since the underlying mean and var might differ from the default values and the model would be trained on "unnormalized" data at the beginning.
Batchnorm layers deeper in the model would also receive activations that were not computed from normalized data, so your whole model could break.

EDIT: just a quick update:
as you see there are a lot of "could/should"s in my post, so don’t let my post stop you from experimenting with this approach, if you think it can work fine. Also, please update me once you’ve run some experiments. :wink:


After some research, I found that this approach (in a slightly modified form) is used in Batch Renormalization. It is not included in PyTorch yet, but some third-party repos like this one are available.
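For reference, a minimal sketch of the training-mode Batch Renormalization forward pass as I understand it from the paper (Ioffe, 2017); it is not the code of any particular repo. The running stats do enter the normalization, but only through the correction factors r and d, which are treated as constants (no gradient). The function name batch_renorm_2d and the default values of r_max, d_max and momentum are hypothetical; also note that the paper tracks a running standard deviation rather than a running variance:

```python
import torch

def batch_renorm_2d(x, running_mean, running_std, weight, bias,
                    r_max=3.0, d_max=5.0, momentum=0.01, eps=1e-5):
    """Training-mode Batch Renormalization forward pass (sketch).

    running_mean / running_std are updated in place. The defaults for
    r_max, d_max and momentum are illustrative, not tuned settings.
    """
    # per-channel batch statistics over (N, H, W)
    mean_b = x.mean(dim=(0, 2, 3))
    std_b = (x.var(dim=(0, 2, 3), unbiased=False) + eps).sqrt()

    # correction factors between batch and running stats; no gradient flows through them
    with torch.no_grad():
        r = (std_b / running_std).clamp(1.0 / r_max, r_max)
        d = ((mean_b - running_mean) / running_std).clamp(-d_max, d_max)

    shape = (1, -1, 1, 1)
    x_hat = (x - mean_b.view(shape)) / std_b.view(shape) * r.view(shape) + d.view(shape)
    y = weight.view(shape) * x_hat + bias.view(shape)

    # moving-average update of the running estimates
    with torch.no_grad():
        running_mean += momentum * (mean_b - running_mean)
        running_std += momentum * (std_b - running_std)
    return y
```

With r_max = 1 and d_max = 0 this reduces to ordinary training-mode batch norm; the paper starts from these values and relaxes them gradually during training.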


So is it correct that at test/eval time only a single mean/std is used, not the whole history that running_mean and running_var store?

The running stats are not storing a history of the values, but are updated to a single value as described in the docs using the current batch stats, the running stats, and the momentum.
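A small sketch that checks this update rule, running = (1 - momentum) * running + momentum * batch_stat, after a single training step. The random input is an assumption for illustration; the default momentum of 0.1 is used. One detail worth noting: the running variance is updated with the unbiased batch variance, while the normalization itself uses the biased one:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(32)  # momentum defaults to 0.1
bn.train()
x = torch.randn(8, 32, 14, 14) * 2 + 1  # hypothetical batch

old_mean = bn.running_mean.clone()
old_var = bn.running_var.clone()
with torch.no_grad():
    bn(x)  # one training-mode forward pass updates the running stats in place

m = bn.momentum
batch_mean = x.mean(dim=(0, 2, 3))
# the running variance is updated with the *unbiased* batch variance
batch_var = x.var(dim=(0, 2, 3), unbiased=True)

expected_mean = (1 - m) * old_mean + m * batch_mean
expected_var = (1 - m) * old_var + m * batch_var
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))  # expected: True
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))    # expected: True
```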

OK, so there are 64 values in running_mean and running_var. Which ones are used at test time, with model.eval()?

All of these values will be used, since the running_mean and running_var contain num_features values, which corresponds to the number of input channels.
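A small sketch of how those per-channel values are applied, assuming a layer with num_features=64 and hypothetical running-mean values 0, 1, ..., 63. In eval mode, channel c of every sample is normalized with running_mean[c] and running_var[c]; reshaping to (1, 64, 1, 1) broadcasts one value per channel over the batch and spatial dimensions:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64)  # num_features = 64 = number of input channels
print(bn.running_mean.shape, bn.running_var.shape)  # torch.Size([64]) torch.Size([64])

# hypothetical running means: channel c gets the value c
with torch.no_grad():
    bn.running_mean.copy_(torch.arange(64, dtype=torch.float32))

bn.eval()
x = torch.zeros(2, 64, 4, 4)
with torch.no_grad():
    out = bn(x)

# with running_var = 1, weight = 1 and bias = 0: out[:, c] == -c / sqrt(1 + eps)
print(out[0, 5, 0, 0].item())
```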

OK, I think I got it: during training, only the current batch's per-channel mean/std are used (plus the scale and shift factors); at test time, the running estimates (also per-channel) are used.