Porting batch normalization from TF to PyTorch

A cheat sheet for porting batch normalization from a TF model to PyTorch.

Param Mapping
bn/gamma → bn.weight
bn/beta → bn.bias
bn/moving_mean → bn.running_mean
bn/moving_variance → bn.running_var
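The mapping above can be applied mechanically when renaming checkpoint entries. A minimal sketch, assuming the PyTorch module path mirrors the TF variable scope (so TF "block1/bn/gamma" becomes PyTorch "block1.bn.weight"); the helper name tf_bn_name_to_torch is hypothetical:

```python
# Leaf-name mapping from the table above (TF -> PyTorch).
TF_TO_TORCH_BN = {
    "gamma": "weight",
    "beta": "bias",
    "moving_mean": "running_mean",
    "moving_variance": "running_var",
}


def tf_bn_name_to_torch(tf_name: str) -> str:
    """Hypothetical helper: translate a TF batch-norm variable name to a
    PyTorch state_dict key, assuming scopes map 1:1 onto module paths."""
    # Drop the ":0" output-index suffix that tf.Variable names carry, if present.
    name = tf_name.split(":")[0]
    *scope, leaf = name.split("/")
    if leaf not in TF_TO_TORCH_BN:
        raise KeyError(f"not a batch-norm variable: {tf_name}")
    return ".".join(scope + [TF_TO_TORCH_BN[leaf]])


print(tf_bn_name_to_torch("block1/bn/gamma:0"))
print(tf_bn_name_to_torch("bn/moving_variance"))
```

If your PyTorch module names do not mirror the TF scopes, the scope part needs its own lookup; only the leaf mapping is fixed.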

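The moving mean and variance are not trainable parameters (in PyTorch they are buffers, so they do not appear in bn.parameters()). When copying weights, don't stop at the trainable variables: the moving statistics must be read in as well, because inference normalizes with them.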
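A sketch of copying all four tensors into an nn.BatchNorm2d, assuming the TF values have already been extracted into NumPy arrays of shape (C,) (for example with tf.train.load_checkpoint(...).get_tensor(name)) and collected in a dict keyed by the short TF names. The helper load_tf_batchnorm and the random stand-in values are illustrative only:

```python
import numpy as np
import torch
import torch.nn as nn


def load_tf_batchnorm(bn: nn.BatchNorm2d, tf_vars: dict) -> None:
    """Hypothetical helper: copy TF batch-norm arrays into a PyTorch layer."""
    with torch.no_grad():
        # Trainable parameters.
        bn.weight.copy_(torch.from_numpy(tf_vars["gamma"]))
        bn.bias.copy_(torch.from_numpy(tf_vars["beta"]))
        # Buffers: not in bn.parameters(), so copy them explicitly.
        bn.running_mean.copy_(torch.from_numpy(tf_vars["moving_mean"]))
        bn.running_var.copy_(torch.from_numpy(tf_vars["moving_variance"]))


C = 4  # number of channels (illustrative)
rng = np.random.default_rng(0)
# Stand-in values; in practice these come from the TF checkpoint.
tf_vars = {
    "gamma": rng.random(C, dtype=np.float32) + 0.5,
    "beta": rng.random(C, dtype=np.float32),
    "moving_mean": rng.random(C, dtype=np.float32),
    "moving_variance": rng.random(C, dtype=np.float32) + 0.1,
}

bn = nn.BatchNorm2d(C, eps=1e-3)  # match TF's default epsilon
load_tf_batchnorm(bn, tf_vars)
bn.eval()  # inference mode: normalize with running_mean / running_var
```

Calling bn.eval() matters here: in training mode the layer would normalize with the current batch statistics and overwrite the loaded running statistics.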

eps
Also set eps explicitly to 1e-3 (e.g. nn.BatchNorm2d(num_features, eps=1e-3)) if the TF model used the TF default epsilon, since the PyTorch default is 1e-5.
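Why the mismatch matters: at inference, batch norm computes gamma * (x - mean) / sqrt(var + eps) + beta, so eps only drops out when the moving variance is much larger than it. A pure-Python sketch for one channel with a small variance (the numbers are made up for illustration):

```python
import math


def bn_inference(x, gamma, beta, mean, var, eps):
    # Inference-mode batch norm for a single channel, using stored statistics.
    return gamma * (x - mean) / math.sqrt(var + eps) + beta


# A channel whose moving variance is comparable to eps.
x, gamma, beta, mean, var = 0.5, 1.0, 0.0, 0.0, 1e-3

y_tf = bn_inference(x, gamma, beta, mean, var, eps=1e-3)     # TF default eps
y_torch = bn_inference(x, gamma, beta, mean, var, eps=1e-5)  # PyTorch default eps
print(y_tf, y_torch)
```

Here the two outputs differ by a factor of sqrt((var + 1e-3) / (var + 1e-5)), about 1.41, so leaving PyTorch's default in place silently changes the activations.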
