Different results for batchnorm with pytorch and tensorflow/ keras

I have an issue using batchnorm in a multi-layer setup. For debugging I initialized both frameworks with the same weights and biases. This works for the linear layers, but I'm not sure it works for all the batchnorm parameters.

(1) How can I use batchnorm to get the same results in PyTorch as in TensorFlow? I want the PyTorch model parameters to be trained in the same way as in TensorFlow. (2) Is there a way to initialize the two batchnorm implementations with exactly the same parameters?

The output of the two frameworks is currently different. Runnable example code is below. Thanks for your help.

from torch import nn
import torch
import tensorflow as tf
from tensorflow import keras

class PytorchModel(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        self.linear1 = nn.Linear(in_1, out_1)
        self.norm1 = nn.BatchNorm1d(num_features=out_1)
        self.linear2 = nn.Linear(out_1, out_2)

    def do_it(self, inputs):
        x = inputs
        x = self.linear1(x)
        x = self.norm1(x)
        x = self.linear2(x)
        return x


class TensorflowModel(keras.Model):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        kernel_initializer = tf.keras.initializers.constant(0.5)
        self.linear1 = keras.layers.Dense(
            out_1, kernel_initializer=kernel_initializer, bias_initializer=kernel_initializer)
        self.norm1 = keras.layers.BatchNormalization()
        self.linear2 = keras.layers.Dense(
            out_2, kernel_initializer=kernel_initializer, bias_initializer=kernel_initializer)

    def do_it(self, inputs):
        x = inputs
        x = self.linear1(x)
        x = self.norm1(x)
        x = self.linear2(x)
        return x


input_dimension = 20
number_samples = 5

# Pytorch
torch_data = torch.ones((number_samples, input_dimension))
torch_model = PytorchModel(in_1=input_dimension, out_1=10, out_2=8)
for name, param in torch_model.named_parameters():
    values = torch.ones(param.shape) * 0.5
    param.data = values
#torch_model.eval()
torch_results = torch_model.do_it(torch_data)
print(torch_results)

# Tensorflow
tf_data = tf.ones((number_samples, input_dimension))
tf_model = TensorflowModel(in_1=input_dimension, out_1=10, out_2=8)
tf_result = tf_model.do_it(tf_data)
print(tf_result)

It should work in the same way for the parameters, since batchnorm layers use the affine .weight and .bias parameters by default. Additionally, you could set the running_mean and running_var to the same initial values (PyTorch uses zeros and ones, respectively) and check the momentum used in both frameworks.
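For example, a minimal sketch (using the torch_model instance from the snippet above and assuming you want to match the Keras defaults) could look like this:

import torch

# Sketch: force the PyTorch batchnorm parameters and buffers to the Keras
# defaults (gamma=1, beta=0, moving_mean=0, moving_variance=1).
with torch.no_grad():
    bn = torch_model.norm1
    bn.weight.fill_(1.0)        # Keras gamma
    bn.bias.fill_(0.0)          # Keras beta
    bn.running_mean.fill_(0.0)  # Keras moving_mean
    bn.running_var.fill_(1.0)   # Keras moving_variance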

Dear @ptrblck,
thanks for your advice. I changed the PyTorch model (the updated source code is listed below) so that only the parameters of the linear layers are set. I also changed the PyTorch batchnorm parameters to the default TensorFlow parameters (the configuration is listed below).

I read in this Stack Overflow post (tensorflow - Why does Keras BatchNorm produce different output than PyTorch? - Stack Overflow) that the PyTorch batchnorm should be run in eval mode ("If you run the pytorch batchnorm in eval mode, you get close results"). Using eval mode in my use case gives the same output results; without eval mode, the outputs differ.

So my questions are: (1) Why does this only work in eval mode? (2) Does eval mode still allow backpropagation (gradient and weight updates) for training neural networks? I thought that in eval mode there is no backpropagation. However, my experiments show that the weights are updated, with only a minimal deviation between TensorFlow and PyTorch.

Batchnorm configuration:

PyTorch (nn.BatchNorm1d):
affine=True
momentum=0.99
eps=0.001
weight=ones
bias=zeros
running_mean=zeros
running_var=ones

TensorFlow (keras.layers.BatchNormalization):
trainable=True
momentum=0.99
epsilon=0.001
gamma=ones
beta=zeros
moving_mean=zeros
moving_variance=ones

Runnable example code:

# Forum Discussion: https://discuss.pytorch.org/t/different-results-for-batchnorm-with-pytorch-and-tensorflow-keras/151691
import numpy as np
from torch import nn
import torch
import tensorflow as tf
from tensorflow import keras


class PytorchModel(nn.Module):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        # linear 1
        self.linear1 = nn.Linear(in_1, out_1)
        linear1_shape = self.linear1.weight.shape
        self.linear1.weight = torch.nn.Parameter(torch.ones(linear1_shape) * 0.5)
        self.linear1.bias = torch.nn.Parameter(torch.ones(linear1_shape[0]) * 0.5)
        # norm1
        self.norm1 = nn.BatchNorm1d(num_features=out_1, momentum=0.99, eps=0.001)
        # linear 2
        self.linear2 = nn.Linear(out_1, out_2)
        linear2_shape = self.linear2.weight.shape
        self.linear2.weight = torch.nn.Parameter(torch.ones(linear2_shape) * 0.5)
        self.linear2.bias = torch.nn.Parameter(torch.ones(linear2_shape[0]) * 0.5)

    def do_it(self, inputs):
        x = inputs
        x = self.linear1(x)
        x = self.norm1(x)
        x = self.linear2(x)
        return x


class TensorflowModel(keras.Model):
    def __init__(self, in_1, out_1, out_2):
        super().__init__()
        kernel_initializer = tf.keras.initializers.constant(0.5)
        self.linear1 = keras.layers.Dense(
            out_1, kernel_initializer=kernel_initializer, bias_initializer=kernel_initializer)
        self.norm1 = keras.layers.BatchNormalization()
        self.linear2 = keras.layers.Dense(
            out_2, kernel_initializer=kernel_initializer, bias_initializer=kernel_initializer)

    def do_it(self, inputs):
        x = inputs
        x = self.linear1(x)
        x = self.norm1(x)
        x = self.linear2(x)
        return x


# Config
number_samples = 5
input_1 = 10
output_1 = 5
output_2 = 1
optimizer_lr = 1e-4
np_data = np.array([
    [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9],
    [2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9],
    [3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9],
    [4.0, 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9]
], dtype=np.float32)

# Pytorch
torch_loss_func = nn.BCEWithLogitsLoss()
torch_data = torch.from_numpy(np_data)
torch_model = PytorchModel(in_1=input_1, out_1=output_1, out_2=output_2)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), betas=(0.5, 0.999), lr=optimizer_lr, eps=1e-07)
torch_model.eval()
torch_output = torch_model.do_it(torch_data)
torch_loss = torch_loss_func(torch_output, torch.ones_like(torch_output))
torch_optimizer.zero_grad()
torch_loss.backward()
torch_optimizer.step()
torch_gradients_5 = torch_optimizer.param_groups[0]['params'][5].grad.cpu().detach().numpy()[0]
print('*** Pytorch')
print('torch_output:\n', torch_output.cpu().detach().numpy())
print('torch_loss: ', torch_loss.cpu().detach().numpy())
print('torch_gradients[5]:', torch_gradients_5)
print('torch.linear2.weights: ', torch_model.linear2.weight.cpu().detach().numpy()[0])
print('torch.linear2.bias: ', torch_model.linear2.bias.cpu().detach().numpy()[0])

# Tensorflow
tf_data = tf.convert_to_tensor(np_data)
tf_loss_func = keras.losses.BinaryCrossentropy(from_logits=True)
tf_model = TensorflowModel(in_1=input_1, out_1=output_1, out_2=output_2)
tf_optimizer = keras.optimizers.Adam(optimizer_lr, beta_1=0.5, epsilon=1e-07)
with tf.GradientTape() as tape:
    tf_output = tf_model.do_it(tf_data)
    tf_loss = tf_loss_func(tf.ones_like(tf_output), tf_output)
tf_variables = tf_model.trainable_weights
tf_gradients = tape.gradient(tf_loss, tf_variables)
tf_optimizer.apply_gradients(zip(tf_gradients, tf_variables))
tf_gradients_5 = tf_gradients[5].numpy()[0]
print('*** Tensorflow')
print('tf_output:', tf_output.numpy())
print('tf_loss:', tf_loss.numpy())
print('tf_gradients[5]:', tf_gradients_5)
print('tf_linear2.weights:', tf_model.linear2.weights[0].numpy())
print('tf_linear2.bias:', tf_model.linear2.bias.numpy()[0])

# Assertions
grad_deviation = abs(tf_gradients_5 - torch_gradients_5)
assert grad_deviation < 1e-07, 'gradients deviation is too large: ' + str(grad_deviation)

Most likely because the momentum definition is different, as PyTorch uses:

Mathematically, the update rule for the running statistics here is x_hat_new = (1 - momentum) * x_hat + momentum * x_t, where x_hat is the estimated statistic and x_t is the new observed value.

and a default value of 0.1 while TF uses:

moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)

and a default value of 0.99.
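In other words, the two momentum arguments are complementary, so PyTorch's momentum corresponds to 1 - momentum in Keras. A minimal sketch (illustrative values, not from the posted code):

from torch import nn
from tensorflow import keras

# Sketch: to mimic the Keras default momentum=0.99, use momentum=0.01 in PyTorch.
# PyTorch: new_stat = (1 - momentum) * running_stat + momentum * batch_stat
# Keras:   new_stat = momentum * moving_stat + (1 - momentum) * batch_stat
bn_pt = nn.BatchNorm1d(num_features=10, momentum=0.01, eps=0.001)
bn_tf = keras.layers.BatchNormalization(momentum=0.99, epsilon=0.001)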

.eval() does not disable the gradient calculation, but changes the normalization strategy: in training mode the batch statistics are used to normalize the inputs and the running stats are updated, while in eval mode the running stats are used to normalize the input.
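A minimal sketch illustrating both points (illustrative shapes, not from the posted code):

import torch
from torch import nn

bn = nn.BatchNorm1d(num_features=4)
x = torch.randn(8, 4, requires_grad=True)

bn.train()
out_train = bn(x)  # normalized with the batch statistics; running stats are updated

bn.eval()
out_eval = bn(x)   # normalized with the running statistics; no update

# gradients are still computed in eval mode
out_eval.sum().backward()
print(x.grad is not None)  # True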

Dear @ptrblck,
thanks for your support, this was very helpful.
Cheers Tarek.

Hi @ptrblck, sorry, I am having a similar problem converting from TensorFlow to PyTorch.

When I try the following code to perform inference (I guess momentum is not relevant for inference, right?)

from torch.nn import BatchNorm2d
import torch
import tensorflow as tf
import numpy as np

torch.use_deterministic_algorithms(True)
torch.backends.cudnn.enabled = False
tf.config.experimental.enable_op_determinism()
torch.set_printoptions(precision=15)
np.set_printoptions(precision=15)

x = np.random.rand(1, 64, 64, 3).astype("float32")
# Tensorflow
x_tf = tf.convert_to_tensor(x)
layer_tf = tf.keras.layers.BatchNormalization(epsilon=1e-3, moving_mean_initializer="random_normal")
outputs_tf = layer_tf(x_tf, training=False)
outputs_tf = outputs_tf.numpy()

x_pt = torch.from_numpy(np.transpose(x, (0, 3, 1, 2)))

# Pytorch
layer_pt = BatchNorm2d(
    num_features=3,
    eps=1e-3,
)
# Converting from tensorflow to pytorch
layer_pt.weight.data = torch.tensor(layer_tf.gamma.numpy())
layer_pt.bias.data = torch.tensor(layer_tf.beta.numpy())
layer_pt.running_mean.data = torch.tensor(layer_tf.moving_mean.numpy())
layer_pt.running_var.data = torch.tensor(layer_tf.moving_variance.numpy())

layer_pt.eval()
with torch.no_grad():
    h = layer_pt(x_pt)
    outputs_pt = np.transpose(h.numpy(), (0, 2, 3, 1))

print(np.array_equal(outputs_tf, outputs_pt))
print(outputs_tf[0, 0, 0])
print(outputs_pt[0, 0, 0])

I am getting

False
[0.22102581 0.28046897 0.9264546 ]
[0.22102581 0.28046894 0.92645454]

which is a small difference, but it accumulates across the network (resulting in output differences >0.1 for InceptionV3, for example). I just wanted to check whether I am missing something or whether, due to implementation differences, the results will differ no matter what.

The differences in the posted results are [0.0000e+00, 2.9802e-08, 5.9605e-08], which is already lower than I would expect to see for float32 and is most likely caused by a different order of operations, and thus expected.
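If you just want to verify the conversion, a tolerance-based comparison is usually more appropriate for float32 than exact equality, e.g. (reusing the names from the snippet above):

import numpy as np

# Sketch: compare with a float32-appropriate tolerance instead of array_equal
print(np.allclose(outputs_tf, outputs_pt, atol=1e-6))
print(np.abs(outputs_tf - outputs_pt).max())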
