Convert TensorFlow Model to PyTorch


I’m porting a TensorFlow model to PyTorch so that I can use its weights as an initialization.
It mainly consists of Conv3D and LayerNorm layers.
I have copied the weights into my PyTorch model, but I am not sure whether I have done it correctly.

I followed this guide:

And did the following:

  1. Load the weights with the following code and find their corresponding layers.
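    import tensorflow as tf

    # ckpt_path: path to the TensorFlow checkpoint (defined elsewhere)
    tf_vars = []
    init_vars = tf.train.list_variables(ckpt_path)  # list of (name, shape) pairs
    for name, shape in init_vars:
        array = tf.train.load_variable(ckpt_path, name)  # NumPy array
        # Caution: squeeze() drops every size-1 axis, e.g. of a kernel with in_channels == 1
        tf_vars.append((name, array.squeeze()))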
  2. Assign the weights to the corresponding Conv3D layers
    and permute them with [4, 3, 0, 1, 2].
    My understanding is that PyTorch stores Conv3d weights in [out_channels, in_channels, depth, height, width] order,
    while TensorFlow stores Conv3D kernels in [depth, height, width, in_channels, out_channels] order.
    So I first permute the TensorFlow array to match the PyTorch layout.

  3. Use GroupNorm with a single group to reproduce the behavior of TensorFlow's LayerNorm.
    PyTorch's LayerNorm normalizes over the trailing dimensions given by normalized_shape, which makes it awkward to use on channels-first image feature maps, so I use GroupNorm instead.
    The weight (gamma) and bias (beta) are assigned accordingly.
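The layout claim behind the [4, 3, 0, 1, 2] permutation can be checked without either framework installed. The sketch below is a NumPy reference, not TensorFlow's or PyTorch's actual kernel: the hypothetical helpers conv3d_tf and conv3d_torch compute the same stride-1, unpadded 3-D cross-correlation (both frameworks' convolutions do not flip the kernel), once in TensorFlow's NDHWC/DHWIO layout and once in PyTorch's NCDHW/OIDHW layout. It assumes the input tensor is converted with transpose(0, 4, 1, 2, 3).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Hypothetical NumPy reference helpers (not TensorFlow or PyTorch APIs):
# stride-1, unpadded ("VALID") 3-D cross-correlation, written once per layout.


def conv3d_tf(x, k):
    """x: (N, D, H, W, C_in), k: (kD, kH, kW, C_in, C_out) -> (N, D', H', W', C_out)."""
    win = sliding_window_view(x, k.shape[:3], axis=(1, 2, 3))  # (N, D', H', W', C_in, kD, kH, kW)
    return np.einsum("ndhwcabe,abeco->ndhwo", win, k)


def conv3d_torch(x, w):
    """x: (N, C_in, D, H, W), w: (C_out, C_in, kD, kH, kW) -> (N, C_out, D', H', W')."""
    win = sliding_window_view(x, w.shape[2:], axis=(2, 3, 4))  # (N, C_in, D', H', W', kD, kH, kW)
    return np.einsum("ncdhwabe,ocabe->nodhw", win, w)


rng = np.random.default_rng(0)
x_tf = rng.standard_normal((2, 5, 6, 7, 3))  # NDHWC input
k_tf = rng.standard_normal((3, 3, 3, 3, 4))  # TensorFlow kernel: [D, H, W, in, out]

x_pt = x_tf.transpose(0, 4, 1, 2, 3)  # assumption: input converted NDHWC -> NCDHW
ref = conv3d_tf(x_tf, k_tf).transpose(0, 4, 1, 2, 3)

good = conv3d_torch(x_pt, k_tf.transpose(4, 3, 0, 1, 2))  # [out, in, D, H, W]
bad = conv3d_torch(x_pt, k_tf.transpose(4, 3, 2, 0, 1))   # [out, in, W, D, H]
print(np.allclose(ref, good), np.allclose(ref, bad))  # True False
```

Under this input conversion, [4, 3, 0, 1, 2] reproduces the TensorFlow output and [4, 3, 2, 0, 1] does not. So if the real model only matches with [4, 3, 2, 0, 1], it may be worth re-checking how the input tensor is transposed, since a different input transpose calls for a different kernel permutation.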

Are the steps above correct? If so, I suspect the problem lies in my implementation.
If not, please let me know which step is wrong.
Thanks in advance.

I was able to bring the difference between the first-layer outputs of the two frameworks down to around 1e-5 by changing the permutation to [4, 3, 2, 0, 1].
However, the second layer still produces different outputs.
I have checked that GroupNorm with one group is very close to LayerNorm.
The layers look like this:

        # Excerpt from the model's __init__ (nn is torch.nn)
        self.conv1_1 = nn.Sequential(
            nn.Conv3d(in_c, nf, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(1, nf),  # one group: normalize over (C, D, H, W)
        )
        self.conv1_2 = nn.Sequential(
            nn.Conv3d(nf, nf * 2, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(1, nf * 2),
        )
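The check that GroupNorm(1, nf) is close to LayerNorm can be made precise with a small NumPy sketch of the math (not of either framework's code). nn.GroupNorm(1, C) normalizes each sample over (C, D, H, W) with the biased variance and then applies a per-channel weight and bias. That matches a TensorFlow layer norm only if the TensorFlow layer also normalizes over all non-batch axes with per-channel gamma/beta (as I understand, the default of tf.contrib.layers.layer_norm); a layer that normalizes over the channel axis only, such as Keras LayerNormalization with its default axis=-1, gives different numbers. The helpers below are hypothetical, and the epsilon is an assumption that should be set to whatever the TensorFlow layer used.

```python
import numpy as np

# Hypothetical NumPy reference helpers, not TensorFlow or PyTorch APIs.
EPS = 1e-5  # assumption: match this to the epsilon of the TensorFlow layer


def group_norm_1(x, gamma, beta, eps=EPS):
    """What nn.GroupNorm(1, C) computes on an NCDHW tensor."""
    mean = x.mean(axis=(1, 2, 3, 4), keepdims=True)  # one mean per sample
    var = x.var(axis=(1, 2, 3, 4), keepdims=True)    # biased variance
    y = (x - mean) / np.sqrt(var + eps)
    return y * gamma.reshape(1, -1, 1, 1, 1) + beta.reshape(1, -1, 1, 1, 1)


def layer_norm_all_axes(x, gamma, beta, eps=EPS):
    """Layer norm over every non-batch axis of an NDHWC tensor, per-channel gamma/beta."""
    mean = x.mean(axis=(1, 2, 3, 4), keepdims=True)
    var = x.var(axis=(1, 2, 3, 4), keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta  # broadcast on the channel axis


def layer_norm_channels_only(x, gamma, beta, eps=EPS):
    """Layer norm over the channel axis only of an NDHWC tensor."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


rng = np.random.default_rng(0)
x_tf = rng.standard_normal((2, 4, 5, 6, 8))  # NDHWC feature map with 8 channels
gamma, beta = rng.standard_normal(8), rng.standard_normal(8)
to_ncdhw = (0, 4, 1, 2, 3)

gn = group_norm_1(x_tf.transpose(to_ncdhw), gamma, beta)
ln_all = layer_norm_all_axes(x_tf, gamma, beta).transpose(to_ncdhw)
ln_chan = layer_norm_channels_only(x_tf, gamma, beta).transpose(to_ncdhw)
print(np.allclose(gn, ln_all), np.allclose(gn, ln_chan))  # True False
```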

The large discrepancy appears at the output of conv1_2.
Does anyone know how to fix this?
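One visible difference between conv1_1 and conv1_2 is the stride. The post does not say which padding the TensorFlow layers use; assuming they use padding='SAME', TensorFlow pads (1, 1) per axis at stride 1, but for an even input size it pads (0, 1) at stride 2, whereas PyTorch's padding=1 always pads (1, 1). The NumPy sketch below uses hypothetical helpers, tf_same_padding and conv1d (one spatial axis standing in for the three axes of Conv3d), to show that the two paddings give outputs of the same shape but different values.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Hypothetical helpers for illustration; not TensorFlow or PyTorch APIs.


def tf_same_padding(in_size, kernel, stride):
    """(before, after) zero padding that padding='SAME' applies along one axis."""
    out_size = -(-in_size // stride)  # ceil(in_size / stride)
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    return total // 2, total - total // 2  # an odd extra pixel goes at the end


def conv1d(x, k, stride, pad_before, pad_after):
    """1-D cross-correlation with explicit zero padding."""
    xp = np.pad(x, (pad_before, pad_after))
    return sliding_window_view(xp, k.size)[::stride] @ k


print(tf_same_padding(32, 3, 1))  # (1, 1): same as PyTorch padding=1
print(tf_same_padding(32, 3, 2))  # (0, 1): PyTorch padding=1 pads (1, 1)

rng = np.random.default_rng(0)
x, k = rng.standard_normal(8), rng.standard_normal(3)
tf_out = conv1d(x, k, 2, *tf_same_padding(x.size, k.size, 2))
torch_like = conv1d(x, k, 2, 1, 1)  # what padding=1, stride=2 computes
print(tf_out.shape == torch_like.shape, np.allclose(tf_out, torch_like))  # True False
```

If that assumption holds, one option would be to pad explicitly with torch.nn.functional.pad, e.g. F.pad(x, (0, 1, 0, 1, 0, 1)) on an NCDHW tensor with even D, H and W (the tuple lists the last dimension first), and build conv1_2's Conv3d with padding=0.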