I am trying to convert a pre-trained TensorFlow model to PyTorch and evaluate it. I load the TensorFlow model and convert each layer's weights into NumPy arrays. My PyTorch model has exactly the same architecture, except that all the weight shapes are flipped. I permute each NumPy array into the matching shape (using torch's permute function) and load it into my PyTorch model, but the accuracy is nowhere near what it was in TensorFlow. I suspect the flipped shapes are the cause, but I'm not sure how to fix it. Is there something I'm missing?
What do you mean when you say that all the layer shapes are flipped? Can you give a toy example with a couple of layers?
Here is an example of the shapes of the same layer in MobileNetV2 for the TensorFlow vs. PyTorch model:
TensorFlow: Conv1/kernel:0: (3, 3, 3, 32)
PyTorch: features.0.0.weight torch.Size([32, 3, 3, 3])
Most of the TensorFlow kernels have the reverse shape of the PyTorch ones, except for the depthwise layers, which need the permutation (2, 3, 0, 1). For example:
TensorFlow: expanded_conv_depthwise/depthwise_kernel:0: (3, 3, 32, 1)
PyTorch: features.1.conv.0.0.weight torch.Size([32, 1, 3, 3])
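One detail worth checking here: if "reverse" means the permutation (3, 2, 1, 0), that gives the right shape for a regular convolution but puts kernel width before height. The axes that map TF's (kh, kw, C_in, C_out) onto PyTorch's (C_out, C_in, kh, kw) are (3, 2, 0, 1). For a 3x3 kernel both permutations produce the same shape, so a shape check cannot tell them apart, but every filter gets transposed. The sketch below uses NumPy only; the helper names are hypothetical, and the depthwise helper assumes the PyTorch layer is built as `nn.Conv2d(C_in, C_in * mult, k, groups=C_in)`.

```python
import numpy as np

def tf_conv_to_torch(kernel):
    """Conv2D kernel: TF (kh, kw, C_in, C_out) -> PyTorch (C_out, C_in, kh, kw)."""
    # (3, 2, 0, 1) moves C_out and C_in to the front but keeps kh before kw.
    return np.ascontiguousarray(np.transpose(kernel, (3, 2, 0, 1)))

def tf_depthwise_to_torch(kernel):
    """DepthwiseConv2D kernel: TF (kh, kw, C_in, mult) -> PyTorch (C_in * mult, 1, kh, kw).

    Assumes the PyTorch layer is nn.Conv2d(C_in, C_in * mult, k, groups=C_in).
    """
    kh, kw, c_in, mult = kernel.shape
    # (C_in, mult, kh, kw); flattening gives output channel c * mult + m, TF's ordering.
    moved = np.transpose(kernel, (2, 3, 0, 1))
    return np.ascontiguousarray(moved.reshape(c_in * mult, 1, kh, kw))

rng = np.random.default_rng(0)
conv_tf = rng.standard_normal((3, 3, 3, 32))  # shaped like Conv1/kernel:0
dw_tf = rng.standard_normal((3, 3, 32, 1))    # shaped like expanded_conv_depthwise/depthwise_kernel:0

conv_pt = tf_conv_to_torch(conv_tf)
dw_pt = tf_depthwise_to_torch(dw_tf)
print(conv_pt.shape, dw_pt.shape)  # (32, 3, 3, 3) (32, 1, 3, 3)

# A full reverse has the same shape for a 3x3 kernel, but every filter is transposed.
reversed_pt = np.transpose(conv_tf, (3, 2, 1, 0))
print(reversed_pt.shape == conv_pt.shape)  # True
print(np.allclose(reversed_pt, conv_pt))   # False
```

With a channel multiplier of 1, the depthwise helper reduces to the (2, 3, 0, 1) permutation described above. The resulting arrays can then be wrapped with `torch.from_numpy` and copied into the model's parameters.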
What are you using to do this import from TensorFlow to PyTorch?
We are importing a TensorFlow model, going layer by layer and converting each layer's weight tensor into a NumPy array. Then we convert each NumPy array to a torch tensor and permute it to the “correct” shape.
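A layer-by-layer loop like this can be made a little safer by choosing the permutation from the Keras variable name and checking each result against the shape PyTorch expects. The sketch below is a hypothetical helper, not part of either library: `name_map` must be built by hand from both models, and `torch_shapes` would in practice come from `{k: tuple(v.shape) for k, v in model.state_dict().items()}`. It assumes the usual Keras suffixes (`kernel`, `depthwise_kernel`, `bias`, and the BatchNorm vectors `gamma`, `beta`, `moving_mean`, `moving_variance`, which correspond to PyTorch's `weight`, `bias`, `running_mean`, `running_var`).

```python
import numpy as np

def torch_axes(tf_name, ndim):
    """Axis permutation from a Keras variable layout to the PyTorch layout (None = keep)."""
    suffix = tf_name.split("/")[-1].split(":")[0]  # e.g. "kernel", "moving_mean"
    if suffix == "depthwise_kernel":
        return (2, 3, 0, 1)  # (kh, kw, C_in, 1) -> (C_in, 1, kh, kw); multiplier 1 only
    if suffix == "kernel" and ndim == 4:
        return (3, 2, 0, 1)  # Conv2D: (kh, kw, C_in, C_out) -> (C_out, C_in, kh, kw)
    if suffix == "kernel" and ndim == 2:
        return (1, 0)        # Dense (in, out) -> Linear (out, in)
    return None              # 1-D vectors: bias, gamma, beta, moving_mean, moving_variance

def convert_weights(tf_weights, name_map, torch_shapes):
    """Convert {tf_name: array} to {torch_name: array}, raising on any shape mismatch."""
    converted = {}
    for tf_name, array in tf_weights.items():
        axes = torch_axes(tf_name, array.ndim)
        out = np.transpose(array, axes) if axes is not None else array
        torch_name = name_map[tf_name]
        expected = tuple(torch_shapes[torch_name])
        if out.shape != expected:
            raise ValueError(f"{tf_name} -> {torch_name}: {out.shape} != {expected}")
        converted[torch_name] = out
    return converted

# Hypothetical example entries, shaped like the layers quoted in this thread.
rng = np.random.default_rng(0)
tf_weights = {
    "Conv1/kernel:0": rng.standard_normal((3, 3, 3, 32)),
    "bn_Conv1/gamma:0": rng.standard_normal(32),
    "expanded_conv_depthwise/depthwise_kernel:0": rng.standard_normal((3, 3, 32, 1)),
}
name_map = {
    "Conv1/kernel:0": "features.0.0.weight",
    "bn_Conv1/gamma:0": "features.0.1.weight",
    "expanded_conv_depthwise/depthwise_kernel:0": "features.1.conv.0.0.weight",
}
torch_shapes = {
    "features.0.0.weight": (32, 3, 3, 3),
    "features.0.1.weight": (32,),
    "features.1.conv.0.0.weight": (32, 1, 3, 3),
}
weights = convert_weights(tf_weights, name_map, torch_shapes)
```

Note that the shape check is necessary but not sufficient: a wrong permutation of a square kernel would still pass it.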
The way in which TensorFlow implements a certain type of layer may be drastically different from the way in which PyTorch implements the same layer. Permuting the axes until the shapes look correct may not be enough.
If you want to do this import on your own, you may have to drill deep into the innards of both libraries to see how each type of layer that you are interested in, translates to tensors in each library.
Can you explain why this may be?
The model we are using is MobileNetV1. We made sure that the ordering of the layers is exactly the same, that the layer names correspond to each other, etc. A conv2d layer should be a Conv2d layer, and a depthwise conv2d is a depthwise conv2d in both. They should produce the same output in both frameworks regardless of the code implementation, because the operation is defined the same way. Can you give an example of how this might not be the case?
I am not saying that this must be the case, just that this may be the case. From what I understand from this conversation, you already noticed that the shapes of layers are flipped across the two libraries. If the layers are implemented in essentially the same way in both libraries, why would such a non-trivial difference exist?
So—if I understood you correctly—you have discovered one type of mismatch between the two libraries’ implementation of the selfsame thing. Given this, all I am saying is that it would be prudent to not assume that things are nearly the same across the two libraries. If you wish to do a thorough job of the porting, I feel you should cross-check the implementation of each type of layer in either library.
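One way to do that cross-check is to feed the same input through one layer in each layout and compare the outputs, rather than comparing shapes. In a real port the two sides would be the TensorFlow layer and the PyTorch layer; in the sketch below both are hypothetical naive NumPy convolutions (VALID padding, stride 1, cross-correlation as both frameworks use), so it runs without either framework. It shows that the (3, 2, 0, 1) kernel permutation reproduces the TF-layout result, while the full reverse does not, even though both pass a shape check.

```python
import numpy as np

def conv2d_nhwc_hwio(x, k):
    """TF-layout VALID, stride-1 conv: x (N, H, W, C_in), k (kh, kw, C_in, C_out)."""
    n, h, w, _ = x.shape
    kh, kw, _, c_out = k.shape
    out = np.zeros((n, h - kh + 1, w - kw + 1, c_out))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (N, kh, kw, C_in)
            out[:, i, j, :] = np.einsum("nhwc,hwco->no", patch, k)
    return out

def conv2d_nchw_oihw(x, k):
    """PyTorch-layout VALID, stride-1 conv: x (N, C_in, H, W), k (C_out, C_in, kh, kw)."""
    n, _, h, w = x.shape
    c_out, _, kh, kw = k.shape
    out = np.zeros((n, c_out, h - kh + 1, w - kw + 1))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # (N, C_in, kh, kw)
            out[:, :, i, j] = np.einsum("nchw,ochw->no", patch, k)
    return out

rng = np.random.default_rng(0)
x_tf = rng.standard_normal((1, 6, 6, 3))  # NHWC input
k_tf = rng.standard_normal((3, 3, 3, 4))  # (kh, kw, C_in, C_out) kernel
reference = conv2d_nhwc_hwio(x_tf, k_tf)

x_pt = np.transpose(x_tf, (0, 3, 1, 2))  # NHWC -> NCHW
good = conv2d_nchw_oihw(x_pt, np.transpose(k_tf, (3, 2, 0, 1)))
bad = conv2d_nchw_oihw(x_pt, np.transpose(k_tf, (3, 2, 1, 0)))

# Bring both results back to NHWC before comparing with the reference.
print(np.allclose(np.transpose(good, (0, 2, 3, 1)), reference))  # True
print(np.allclose(np.transpose(bad, (0, 2, 3, 1)), reference))   # False
```

Running the same comparison layer by layer on the real models narrows an accuracy drop down to the first layer whose outputs diverge.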
That's a fair point. I was hoping it would be an easy port, but I will try this when we have more time.
You may want to take a look at ONNX. I found a tutorial for converting a PyTorch model to TensorFlow using ONNX. The other direction should also be possible, though I couldn’t find an example in five minutes of searching.
I will definitely take a look, cool! But the goal was really to do it from scratch. Maybe I can look at ONNX's internals to see how it converts the layers.
Conv2d and BatchNorm need some attention while porting. I posted about handling the constant eps value in my message about porting batch normalization. I have ported ResNet architectures from TensorFlow to PyTorch without any loss; the errors mostly creep in because of constants like eps values, shapes, and stride implementation differences, as I mentioned in this message. In my experience, the flipping comes from the NCHW vs. NHWC layout conventions (N being batch, C channels, H height and W width); if you flip everything consistently while reading from the TF model and setting it in Torch, it will work.
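Two concrete instances of these pitfalls, sketched in NumPy under the assumption that the TF model uses `padding="same"` and Keras `BatchNormalization`. First, TF's "SAME" padding is asymmetric for a stride-2 convolution on an even-sized input: it pads 0 pixels before and 1 after, whereas `nn.Conv2d(..., padding=1)` pads 1 on both sides. The output size matches, but every window is shifted by one pixel; one way to mimic TF is `torch.nn.functional.pad(x, (0, 1, 0, 1))` followed by a conv with `padding=0`. Second, Keras `BatchNormalization` defaults to `epsilon=1e-3` while `nn.BatchNorm2d` defaults to `eps=1e-5`, which matters for channels with small variance; passing `eps=1e-3` (or whatever epsilon the Keras layer was built with) avoids this. The helper names below are hypothetical.

```python
import math
import numpy as np

def tf_same_padding(in_size, kernel, stride):
    """(pad_before, pad_after) used by TF 'SAME' padding along one axis (dilation 1)."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    return total // 2, total - total // 2

# Stride-2 3x3 conv on a 224-pixel axis: TF pads (0, 1), PyTorch padding=1 pads (1, 1).
print(tf_same_padding(224, 3, 2))  # (0, 1)
print(tf_same_padding(112, 3, 1))  # (1, 1)

def batchnorm_inference(x, gamma, beta, mean, var, eps):
    """Inference-mode BatchNorm using stored statistics (per-channel vectors broadcast)."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.array([0.5])
params = dict(gamma=np.array([1.0]), beta=np.array([0.0]),
              mean=np.array([0.0]), var=np.array([1e-4]))  # a low-variance channel
y_keras = batchnorm_inference(x, eps=1e-3, **params)  # Keras BatchNormalization default
y_torch = batchnorm_inference(x, eps=1e-5, **params)  # nn.BatchNorm2d default
print(y_keras, y_torch)  # noticeably different outputs for the same weights
```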