TF 1.x batch_normalization to PyTorch BatchNorm2d

I’m trying to convert an old TensorFlow 1.x model to PyTorch. I have written a script that exports the TF model weights and loads them into the PyTorch model. However, I still have problems with the batch normalization. This is the original TF code:

def cnn_block(inputs, filters, kernel_size):
    cnn = tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        padding="same",
        activation=None,
        use_bias=False
    )
    output = tf.layers.batch_normalization(
        inputs=cnn,
        momentum=0.9,
        epsilon=1e-5,
        center=True,
        scale=True,
    )
    return output

I have implemented the following PyTorch code:

import torch
from torch import nn

class CNNBlock(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=False
            ),
            nn.BatchNorm2d(
                num_features=out_channels,
                eps=1e-5,
                momentum=0.9,
                affine=True
            )
        )

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.model(tensor)

However, nn.BatchNorm2d() returns different results than tf.layers.batch_normalization().

In the weight-loading routine, I copy the gamma value to model.weight, the beta value to model.bias, the moving_mean value to model.running_mean, and the moving_variance value to model.running_var. However, the results are quite different.

If I comment out the batch normalization in both versions, Conv2d() returns basically the same values.

The definition of momentum differs between the two frameworks.
TF uses a default of 0.99, because the momentum is multiplied with the moving stats:

moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
moving_var = moving_var * momentum + var(batch) * (1 - momentum)

PyTorch uses a default of 0.1, because the momentum is multiplied with the new observed value. From the docs:

This momentum argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is $\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$, where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value.
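To make the two conventions concrete, here is a minimal sketch in plain Python. The function names tf_update and torch_update and the numbers are made up for illustration and are not part of either library. It suggests that a TF momentum of 0.9 corresponds to a PyTorch momentum of 1 - 0.9 = 0.1:

```python
def tf_update(moving, batch_stat, momentum):
    # TF 1.x convention: momentum weights the OLD moving statistic.
    return moving * momentum + batch_stat * (1.0 - momentum)


def torch_update(running, batch_stat, momentum):
    # PyTorch convention: momentum weights the NEW batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat


tf_momentum = 0.9
torch_momentum = 1.0 - tf_momentum  # the equivalent PyTorch setting

moving_mean = 0.0  # illustrative starting value
batch_mean = 2.0   # illustrative batch statistic

tf_result = tf_update(moving_mean, batch_mean, tf_momentum)
torch_result = torch_update(moving_mean, batch_mean, torch_momentum)

# Reusing the TF value unchanged in PyTorch gives a very different update.
same_momentum_result = torch_update(moving_mean, batch_mean, tf_momentum)

print(tf_result, torch_result, same_momentum_result)
```

With these numbers the matched pair both give about 0.2, while passing 0.9 straight to the PyTorch rule gives about 1.8.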

@ptrblck, thanks for the fast reply.
So do I have to implement a new nn.Module to normalize in the same way as TF?

You could set the PyTorch momentum to 1 minus the TF momentum (i.e. 0.1 for your TF value of 0.9) and check whether the results match.

As I understand it, the momentum is only used during training, right?

I’m exporting a trained model from TF to PyTorch, so the momentum should make no difference to the normalization at inference time.
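A small sketch of why the momentum should not matter for a trained model: in eval mode, BatchNorm2d normalizes with the stored running statistics, and the momentum only comes into play when those statistics are updated in train mode. The shapes and parameter values below are invented for illustration:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)  # (N, C, H, W), arbitrary illustrative shape

# Two layers that differ only in momentum.
bn_a = nn.BatchNorm2d(3, eps=1e-5, momentum=0.1)
bn_b = nn.BatchNorm2d(3, eps=1e-5, momentum=0.9)

# Give both layers the same (made-up) parameters and running statistics.
with torch.no_grad():
    for bn in (bn_a, bn_b):
        bn.weight.copy_(torch.tensor([1.5, 0.5, 2.0]))
        bn.bias.copy_(torch.tensor([0.1, -0.2, 0.3]))
        bn.running_mean.copy_(torch.tensor([0.2, -0.1, 0.0]))
        bn.running_var.copy_(torch.tensor([0.8, 1.2, 1.0]))


def per_channel(t):
    # Reshape a (C,) vector so it broadcasts over (N, C, H, W).
    return t.view(1, -1, 1, 1)


bn_a.eval()
bn_b.eval()
with torch.no_grad():
    out_a = bn_a(x)
    out_b = bn_b(x)
    # Inference formula written out by hand.
    scale = per_channel(bn_a.weight) / torch.sqrt(per_channel(bn_a.running_var) + bn_a.eps)
    manual = (x - per_channel(bn_a.running_mean)) * scale + per_channel(bn_a.bias)

eval_matches = torch.allclose(out_a, out_b)
manual_matches = torch.allclose(out_a, manual, atol=1e-5)

# A freshly built module is in train mode by default; there the batch
# statistics are used, so the output no longer follows the inference formula.
bn_a.train()
with torch.no_grad():
    out_train = bn_a(x)
train_differs = not torch.allclose(out_train, manual, atol=1e-3)

print(eval_matches, manual_matches, train_differs)
```

If the outputs of the two frameworks still differ after copying all four tensors, checking whether the PyTorch module is in eval mode is a reasonable first step.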

In my script, I load the trained TF weights using the pickle package and copy them to the PyTorch model. However, when I send the same tensor to both models (TF and PyTorch), the normalization results differ. This is the code I’m using to load the trained weights:

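import pickle

import torch

# pickle.load expects an open binary file object, not a path
with open(FILENAME, "rb") as f:
    data = pickle.load(f)

model = tf_model.layer[1].model[2]
# in-place copies into parameters that require grad must run under no_grad
with torch.no_grad():
    tensor = torch.tensor(
        data["normalization_layer/gamma:0"].copy(),
        dtype=model.weight.dtype
    )
    model.weight.copy_(tensor)
    tensor = torch.tensor(
        data["normalization_layer/beta:0"].copy(),
        dtype=model.bias.dtype
    )
    model.bias.copy_(tensor)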

I found a solution! The problem occurred when I exported the model to ONNX: I had to set training to TrainingMode.EVAL so that the batch normalization uses the stored moving statistics instead of the batch statistics. With that, the TensorFlow 1.x and PyTorch models compute almost the same values (there are slight floating-point differences). This was my final solution:

import pickle

import torch

# pickle.load expects an open binary file object, not a path
with open(FILENAME, "rb") as f:
    data = pickle.load(f)

model = tf_model.layer[1].model[2]
# in-place copies into parameters that require grad must run under no_grad
with torch.no_grad():
    tensor = torch.tensor(
        data["normalization_layer/gamma:0"].copy(),
        dtype=model.weight.dtype
    )
    model.weight.copy_(tensor)
    tensor = torch.tensor(
        data["normalization_layer/beta:0"].copy(),
        dtype=model.bias.dtype
    )
    model.bias.copy_(tensor)
    # running_mean and running_var are buffers, not trainable parameters
    tensor = torch.tensor(
        data["normalization_layer/moving_mean:0"].copy(),
        dtype=model.running_mean.dtype
    )
    model.running_mean.copy_(tensor)
    tensor = torch.tensor(
        data["normalization_layer/moving_variance:0"].copy(),
        dtype=model.running_var.dtype
    )
    model.running_var.copy_(tensor)

batch_size = 1
tensor = torch.randn(batch_size, C, H, W, requires_grad=True)

torch.onnx.export(
    model,
    tensor,
    "model.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.EVAL,
    input_names = ["input"],
    output_names = ["output"],
    dynamic_axes={
        "input" : {0 : "batch_size"},
        "output" : {0 : "batch_size"}
    }
)
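
For reference, the copying pattern above could be wrapped in a small helper and checked in plain PyTorch before exporting. This is only a sketch: load_tf_batchnorm is a hypothetical name, the dictionary is a made-up stand-in for the pickled export (a real export would hold NumPy arrays, which torch.as_tensor also accepts), and the key names follow the ones used in this thread.

```python
import torch
from torch import nn


def load_tf_batchnorm(bn, data, prefix):
    """Copy TF 1.x batch-norm variables from `data` into a BatchNorm2d layer."""
    targets = {
        "gamma:0": bn.weight,
        "beta:0": bn.bias,
        "moving_mean:0": bn.running_mean,
        "moving_variance:0": bn.running_var,
    }
    # no_grad is needed because weight and bias are parameters that require grad.
    with torch.no_grad():
        for suffix, target in targets.items():
            values = torch.as_tensor(data[f"{prefix}/{suffix}"], dtype=target.dtype)
            target.copy_(values)
    return bn


# Made-up stand-in for the pickled TF export (two channels).
data = {
    "normalization_layer/gamma:0": [1.0, 2.0],
    "normalization_layer/beta:0": [0.5, -0.5],
    "normalization_layer/moving_mean:0": [0.1, 0.2],
    "normalization_layer/moving_variance:0": [4.0, 9.0],
}

bn = load_tf_batchnorm(nn.BatchNorm2d(2, eps=1e-5), data, "normalization_layer")
bn.eval()  # use the copied moving statistics, as TF does at inference

x = torch.zeros(1, 2, 3, 3)
with torch.no_grad():
    out = bn(x)
```

Calling eval() on the module plays the same role here as training=torch.onnx.TrainingMode.EVAL does in the export call.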