Converting a batch normalization layer from TF to PyTorch

How do I convert the following batch normalization layer from TensorFlow to PyTorch?

tf.contrib.layers.batch_norm(inputs=x, 
                             decay=0.95,
                             center=True,
                             scale=True,
                             is_training=(mode=='train'),
                             updates_collections=None,
                             reuse=reuse,
                             scope=(name+'batch_norm'))

I couldn’t find some of these arguments in PyTorch’s batch norm layer.

Based on the docs, let’s compare the arguments one by one.
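decay corresponds to 1 - momentum in PyTorch: TF’s decay weights the old running average, while PyTorch’s momentum weights the new batch statistic.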
center and scale add the learnable shift (beta) and scale (gamma); PyTorch enables both together with the affine flag.
is_training can be achieved by calling .train() (or .eval() for inference) on the module.
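A quick sketch to check these three mappings, assuming a 4D NCHW input with 3 channels (the shapes and the `mode` flag are placeholders mirroring the TF snippet). After one training-mode forward pass, the running mean should equal `(1 - decay) * batch_mean`, which is what TF’s moving average gives from a zero start, and in eval mode the running stats should stay untouched:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

decay = 0.95  # value from the TF call
bn = nn.BatchNorm2d(num_features=3, affine=True, momentum=1 - decay)

# affine=True creates both parameters: weight plays the role of TF's gamma
# (scale=True) and bias plays the role of beta (center=True)
print(bn.weight.shape, bn.bias.shape)  # torch.Size([3]) torch.Size([3])

x = torch.randn(8, 3, 4, 4)  # NCHW input (placeholder shape)

mode = 'train'               # hypothetical flag, as in the TF snippet
bn.train(mode == 'train')    # same as is_training=(mode=='train')
bn(x)
# running_mean starts at 0, so one update gives momentum * batch_mean,
# i.e. (1 - decay) * batch_mean, like TF's moving average
expected = (1 - decay) * x.mean(dim=(0, 2, 3))
print(torch.allclose(bn.running_mean, expected, atol=1e-6))  # True

bn.eval()                    # is_training=False
frozen = bn.running_mean.clone()
bn(x)
# eval mode normalizes with the running stats but does not update them
print(torch.equal(bn.running_mean, frozen))  # True
```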

I’m not sure what updates_collections, reuse and scope mean, and the docs are quite confusing to me.
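For what it’s worth, here is my reading of the TF docs, so treat it as an assumption rather than a definitive mapping. updates_collections=None makes TF update the moving mean/variance in place during the forward pass instead of collecting the update ops for you to run with the train op, and PyTorch’s BatchNorm already does this in training mode. reuse and scope are TF’s variable-sharing mechanism; in PyTorch you share a layer by calling the same module instance again, and the attribute name you give it ends up as the prefix of its parameter names, somewhat like a scope. A small sketch (`TwoInputs` is a made-up name):

```python
import torch
import torch.nn as nn


class TwoInputs(nn.Module):
    """Hypothetical model: both inputs go through ONE shared batch norm,
    roughly what reuse=True under the same scope gives you in TF."""

    def __init__(self, features):
        super().__init__()
        # the attribute name acts a bit like the TF scope: it prefixes
        # the parameter names in named_parameters() and state_dict()
        self.batch_norm = nn.BatchNorm2d(features, momentum=0.05)

    def forward(self, a, b):
        # calling the same module twice reuses the same weight/bias
        # and updates the same running statistics in place
        return self.batch_norm(a), self.batch_norm(b)


model = TwoInputs(16).train()
out_a, out_b = model(torch.randn(4, 16, 8, 8), torch.randn(4, 16, 8, 8))
print([name for name, _ in model.named_parameters()])
# ['batch_norm.weight', 'batch_norm.bias']
```

Both calls update the same running statistics, so `model.batch_norm.num_batches_tracked` is 2 after the forward pass.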

Your layer would therefore look like:

import torch.nn as nn

features = 64  # number of channels in x (placeholder value)
bn = nn.BatchNorm2d(
    num_features=features,
    affine=True,
    momentum=0.05  # = 1 - decay
).train()

PS: some of these settings, like affine=True and training mode, are already the defaults, but I’ve added them for clarity.
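Two more defaults differ between the APIs and may matter if you need numerically matching outputs. tf.contrib.layers.batch_norm uses epsilon=0.001 unless told otherwise, whereas nn.BatchNorm2d uses eps=1e-5. The TF layer also defaults to NHWC data, while BatchNorm2d normalizes over dim 1 of an NCHW tensor. Below is a minimal sketch, assuming your activations are still NHWC after porting. The `BatchNormNHWC` wrapper is hypothetical; if your PyTorch model is NCHW throughout, drop the permutes.

```python
import torch
import torch.nn as nn


class BatchNormNHWC(nn.Module):
    """Hypothetical wrapper: BatchNorm2d applied to an NHWC tensor."""

    def __init__(self, features):
        super().__init__()
        # eps=1e-3 mirrors the default epsilon of tf.contrib.layers.batch_norm
        self.bn = nn.BatchNorm2d(features, eps=1e-3, momentum=0.05, affine=True)

    def forward(self, x):
        # NHWC -> NCHW, normalize per channel, then back to NHWC
        x = x.permute(0, 3, 1, 2)
        x = self.bn(x)
        return x.permute(0, 2, 3, 1)


layer = BatchNormNHWC(features=16).train()
y = layer(torch.randn(4, 8, 8, 16))
print(y.shape)  # torch.Size([4, 8, 8, 16])
```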