I want to understand the exact implementation inside BatchNorm2d, i.e. the exact equation computed by net.bn1(tem1). By “equivalent” I mean that both should produce the same output.
I think running_mean and running_var are computed during training, but are used for normalization only during testing. During training, the batch statistics (mean and std) are used directly for normalization. @ptrblck Is that right?
As far as BatchNorm2d computation is concerned, the following works:
bn_in = ...  # input tensor to the batch norm layer, shape (N, C, H, W)
num_channels = bn_in.shape[1]  # same as out_channels of the preceding conv
# flatten the spatial dimensions
bn_in_flat = bn_in.view(bn_in.shape[0], num_channels, -1)
# we need mean and std per channel across the whole batch, so move the
# channel dimension to the front and flatten the rest
bn_in_flat = bn_in_flat.permute(1, 0, 2).reshape(num_channels, -1)
# per-channel mean
bn_in_mean = bn_in_flat.mean(dim=-1).view(1, num_channels, 1, 1)
# per-channel biased std (BatchNorm normalizes with the biased variance)
bn_in_std = bn_in_flat.std(dim=-1, unbiased=False).view(1, num_channels, 1, 1)
# BatchNorm2d computation
bn_out = (net.bn1.weight.view(1, num_channels, 1, 1) *
          (bn_in - bn_in_mean) / (bn_in_std**2 + net.bn1.eps).sqrt() +
          net.bn1.bias.view(1, num_channels, 1, 1))
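You can verify the equivalence with a self-contained check against nn.BatchNorm2d in training mode (the shapes, seed, and number of features here are just for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_features = 3
bn = nn.BatchNorm2d(num_features)
bn.train()

x = torch.randn(4, num_features, 5, 5)

# reference output from the module (training mode -> batch statistics)
ref = bn(x)

# manual computation: per-channel mean and biased variance over (N, H, W)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
manual = ((x - mean) / (var + bn.eps).sqrt()
          * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1))

print(torch.allclose(ref, manual, atol=1e-6))  # True
```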
Yes, you are correct. The running estimates are used during eval. If track_running_stats is set to False, the batch statistics will be used during training and eval.
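A small sketch of the eval-mode behavior (toy shapes, seed just for reproducibility): after a forward pass in training mode has updated the running estimates, the eval-mode output can be reproduced from running_mean and running_var directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 5, 5)

bn.train()
bn(x)  # forward pass updates running_mean / running_var

bn.eval()
out = bn(x)

# in eval mode the stored running estimates are used, not the batch stats
manual = ((x - bn.running_mean.view(1, -1, 1, 1))
          / (bn.running_var.view(1, -1, 1, 1) + bn.eps).sqrt()
          * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1))
print(torch.allclose(out, manual, atol=1e-6))  # True
```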
I am wondering why running_mean and running_var are not used during training. Intuitively, that would make the model more stable, and the convergence would be faster.
These running stats are updated during training and are initialized to default values of zeros for running_mean and ones for running_var. Using them for normalization at the start of training, without having updated them first, therefore wouldn’t normalize the data at all.
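You can see the default initialization directly on a freshly created layer:

```python
import torch
import torch.nn as nn

# a brand-new BatchNorm2d layer: running stats start at the identity transform
bn = nn.BatchNorm2d(4)
print(bn.running_mean)  # tensor([0., 0., 0., 0.])
print(bn.running_var)   # tensor([1., 1., 1., 1.])
```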
Also, even if you update the stats during training and use them for normalization at the same time, I would assume your model might perform quite badly, since the true mean and var can differ considerably from the default values, so the model would initially be trained on “unnormalized” data.
Batchnorm layers deeper in the model would then also receive activations that were not created from normalized data, so your complete model could break.
EDIT: just a quick update:
As you can see, there are a lot of “could/should”s in my post, so don’t let it stop you from experimenting with this approach if you think it can work. Also, please update me once you’ve run some experiments.
After some research, I found that this approach (in a slightly modified form) is used in batch renormalization. It is not included in PyTorch yet, but some third-party repos like this one are available.
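For reference, the batch renorm correction can be sketched roughly like this. This is a hypothetical minimal version based on the paper (Ioffe, 2017), not the linked repo’s API; the function name and the rmax/dmax values are illustrative, and the affine scale/shift and running-stat updates are omitted:

```python
import torch

def batch_renorm_sketch(x, running_mean, running_var, eps=1e-5,
                        rmax=3.0, dmax=5.0):
    # normalize with batch stats, then correct toward the running stats
    # via the clipped factors r and d, so training and inference see
    # more consistent statistics
    batch_mean = x.mean(dim=(0, 2, 3))
    batch_std = x.var(dim=(0, 2, 3), unbiased=False).add(eps).sqrt()
    running_std = (running_var + eps).sqrt()
    r = (batch_std / running_std).clamp(1.0 / rmax, rmax).detach()
    d = ((batch_mean - running_mean) / running_std).clamp(-dmax, dmax).detach()
    x_hat = (x - batch_mean.view(1, -1, 1, 1)) / batch_std.view(1, -1, 1, 1)
    return x_hat * r.view(1, -1, 1, 1) + d.view(1, -1, 1, 1)

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)
out = batch_renorm_sketch(x, torch.zeros(3), torch.ones(3))
print(out.shape)  # torch.Size([4, 3, 5, 5])
```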
The running stats are not storing a history of the values, but are updated to a single value as described in the docs using the current batch stats, the running stats, and the momentum.
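The update rule from the docs can be checked numerically. Note that the running_var update uses the unbiased batch variance, while the normalization itself uses the biased one (shapes and seed here are just for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3, momentum=0.1)
x = torch.randn(4, 3, 5, 5)

old_mean = bn.running_mean.clone()
old_var = bn.running_var.clone()

bn.train()
bn(x)  # one forward pass updates the running stats

batch_mean = x.mean(dim=(0, 2, 3))
# the running_var update uses the *unbiased* batch variance
batch_var = x.var(dim=(0, 2, 3), unbiased=True)

expected_mean = (1 - bn.momentum) * old_mean + bn.momentum * batch_mean
expected_var = (1 - bn.momentum) * old_var + bn.momentum * batch_var

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-6))   # True
```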
All of these values will be used, since the running_mean and running_var contain num_features values, which corresponds to the number of input channels.
OK, I think I got it: during training only the current per-channel batch mean/std are used (plus the learnable scale and shift factors), while at test time the running estimates (also per-channel) are used.