Convert TF batchnorm to PyTorch

Hi, I’m trying to convert a pretrained TensorFlow model to PyTorch.

The network is quite simple, with just one 1D conv layer and one 1D batchnorm layer.

The output is the same after the 1D conv layer. However, I cannot replicate the performance/output after the batchnorm layer.

More specifically, there seems to be a bug in the original TensorFlow code.

The author seems to have forgotten to add the running mean/variance of the TensorFlow batchnorm layer to the ‘trainable variables’, so they are never really updated in the model he provided. And the is_train flag is always set to True, even during evaluation.

However, when I load the model in PyTorch and set the running mean to all zeros and the running variance to all ones, the output is different no matter whether I set the model to train or eval.

Nonetheless, the overall performance of the TensorFlow model is great. If I remove all the bn layers, I can get the same output in my PyTorch code, but when I add batchnorm back and retrain the network in PyTorch, the performance is bad, about 20~30 percent lower.

Please check the original TensorFlow code for details.

And here’s my code for the problem above.

I’m not that familiar with tensorflow, but are the running estimates treated as “trainable variables”? Since they are updated based on the current batch statistics, they won’t be updated using the gradient + optimizer.
Are you sure the running estimates are not updated, if they are not in that “trainable variables” list?
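
Just to illustrate the difference on the PyTorch side (a minimal standalone sketch, not your model): the running estimates are registered as buffers, not parameters, so the optimizer never touches them, but a forward pass in train() mode still updates them.

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)

# gamma/beta are trainable parameters, the running estimates are buffers
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in bn.named_buffers()])     # ['running_mean', 'running_var', 'num_batches_tracked']

x = torch.randn(8, 4)
bn.train()
bn(x)                   # updates the running estimates from the batch statistics
print(bn.running_mean)  # no longer all zeros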

In your current code you are setting the weight and bias to these values: lines of code.
These parameters correspond to the gamma and beta parameters from the original batchnorm paper.
However, the running_mean and running_var will be initialized by default to zeros and ones, respectively.
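
A rough sketch of how the copying could look (gamma, beta, moving_mean, moving_variance are placeholders for the numpy arrays you would read out of the TF checkpoint, not your actual code):

import torch
import torch.nn as nn

def load_tf_batchnorm(bn: nn.BatchNorm1d, gamma, beta, moving_mean, moving_variance):
    # copy the TF batchnorm variables into the corresponding PyTorch attributes
    with torch.no_grad():
        bn.weight.copy_(torch.as_tensor(gamma))               # scale (gamma)
        bn.bias.copy_(torch.as_tensor(beta))                  # shift (beta)
        bn.running_mean.copy_(torch.as_tensor(moving_mean))
        bn.running_var.copy_(torch.as_tensor(moving_variance))
    return bn

If I’m not mistaken, some TF batchnorm wrappers also use a larger default eps than nn.BatchNorm1d (1e-5), which is worth double-checking when comparing outputs.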

@ptrblck
Well, I’m not that familiar with TensorFlow myself either, and that’s the reason why I want to convert this model to PyTorch in the first place.

In fact, after looking into TensorFlow’s docs, I think the running mean and variance are not updated in the GitHub repo, and they are not in the ‘trainable variables’ by default in TensorFlow. These lines of code are missing:

  # make the train op depend on the batchnorm moving-average update ops
  update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

What’s more, I loaded the pretrained model provided by the author in TensorFlow, and it gives me reasonable performance. In fact, I can print the weights of the first batchnorm layer with the following lines of code:

# fetch the batchnorm variables of the first residual block
bn_weight = tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope="ResidualBlock_1D_0/BatchNorm/")
sess.run(bn_weight)

And here’s the output:

(Pdb) print(bn_weight)
[<tf.Variable 'ResidualBlock_1D_0/BatchNorm/beta:0' shape=(28,) dtype=float32_ref>, <tf.Variable 'ResidualBlock_1D_0/BatchNorm/moving_mean:0' shape=(28,) dtype=float32_ref>, <tf.Variable 'ResidualBlock_1D_0/BatchNorm/moving_variance:0' shape=(28,) dtype=float32_ref>]
(Pdb) sess.run(bn_weight)
[array([ 0.00602053,  0.0050063 ,  0.00634659,  0.00489874, -0.00110139,
       -0.0022672 ,  0.00327282,  0.00230329, -0.00308617,  0.0024446 ,
        0.00504038,  0.00412193, -0.00564392,  0.00558108,  0.00340598,
        0.00164128, -0.00063161,  0.00081446, -0.00207535,  0.00168148,
       -0.00086416,  0.00218814,  0.00166753, -0.00242493,  0.00044503,
       -0.00097768,  0.00121772,  0.00186761], dtype=float32), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)]

And I think I’m doing the right thing in PyTorch by setting the running mean to zeros and the running variance to ones. However, the output is just different, which is weird.

I also tried setting track_running_stats to False or affine to True in PyTorch, but it doesn’t help at all. :roll_eyes::roll_eyes:
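
For context, here is a small standalone sketch of the two modes I’m comparing against the TF output (shapes are placeholders, not my actual model): with running_mean all zeros and running_var all ones, eval() basically only applies the affine transform, while train() normalizes with the current batch statistics.

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(28)
with torch.no_grad():
    bn.running_mean.zero_()     # already the default init, shown here for clarity
    bn.running_var.fill_(1.0)

x = torch.randn(4, 28, 100)     # (batch, channels, length), placeholder shape

bn.eval()
out_eval = bn(x)    # uses the zero/one running stats, i.e. essentially just the affine transform

bn.train()
out_train = bn(x)   # normalizes with the batch statistics instead

print((out_eval - out_train).abs().max())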

Hi, @ptrblck

A similar question about JAX and PyTorch.

For BatchNorm, I noticed that there is a difference between JAX and PyTorch.

I.e.

JAX computation order: (inputs - mean) * (scale * rsqrt(var + eps))

PyTorch computation order: (inputs - mean) * scale * rsqrt(var + eps)

I know they are theoretically equivalent.

Do you think this computation order would make a difference?

Thanks in advance.

You might be running into floating point rounding errors, which could create a small difference between the two implementations, but I wouldn’t expect to see any effect on your training etc.
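
A quick way to see the size of that rounding difference (a standalone sketch, not tied to either framework’s internals):

import torch

torch.manual_seed(0)
x = torch.randn(1024, 64)
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
scale = torch.randn(64)
eps = 1e-5

# same math, different grouping of the multiplications
out_a = (x - mean) * (scale * torch.rsqrt(var + eps))   # "JAX" order
out_b = (x - mean) * scale * torch.rsqrt(var + eps)     # "PyTorch" order

print((out_a - out_b).abs().max())  # tiny, on the order of float32 rounding error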
