Comparing PyTorch and TensorFlow models

Hi, what would you do if you had PyTorch and TensorFlow models that are supposed to be replicas of each other, and you wanted to verify that they match exactly in terms of layers, weight initialization, loss functions, etc.? I converted a TF model to PyTorch, but the PyTorch accuracy is far off from the TensorFlow one. I thought about feeding both models a single sample (or using a seed to force the same sample), using the same seed for weight initialization, and then inspecting the tensor values to see where they diverge. Do you think this is a good idea, and is it doable? Or would you have done something else?

Two challenges I will face in that case:

  1. How do I seed both models so that they behave exactly the same?
  2. Debugging is much harder in TF than in PyTorch.

I would appreciate your input.

Thanks!

I wouldn’t try to use seeds to initialize both models identically, since in the worst case the two frameworks implement their random number generators differently (so the same seed would not produce the same values) and you would waste a lot of time.
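A small illustration of the point, using NumPy only as an analogy rather than TF vs. PyTorch themselves: two generators seeded with the same value still produce different streams when their algorithms differ. On top of that, the default initializers differ (e.g. Keras `Dense` defaults to Glorot uniform, while `nn.Linear` uses a Kaiming-uniform-based scheme), so generating the weights once and loading them into both models sidesteps the RNG question entirely.

```python
import numpy as np

# Same seed, two different generator algorithms (MT19937 vs. PCG64).
legacy = np.random.RandomState(0).random_sample(3)
modern = np.random.default_rng(0).random(3)

# Each generator is reproducible on its own...
assert np.array_equal(legacy, np.random.RandomState(0).random_sample(3))

# ...but the two streams differ even though both were seeded with 0.
print(legacy, modern)
```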

If you want to compare both models, I would start with the well-performing model (i.e. the TF model in your case), create a random input, and compute the output.
Once this is done, store all data as numpy arrays (i.e. the input, output, model parameters and buffers).
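A minimal sketch of the export step, assuming everything is bundled into a single `.npz` file. The helper names `save_reference`/`load_reference` are hypothetical, and the Keras expression in the comment for collecting the weights is an assumption (variable names differ between Keras versions, so per-layer `get_weights()` is used instead of `model.weights` names).

```python
import re
import numpy as np

def save_reference(path, inputs, outputs, weights):
    """Store input, output and named weights in one .npz file (hypothetical helper).

    `weights` maps name -> array. With tf.keras it could be built (assumption) as
    {f"{l.name}/{i}": w for l in tf_model.layers for i, w in enumerate(l.get_weights())}
    """
    arrays = {"input": np.asarray(inputs), "output": np.asarray(outputs)}
    for name, value in weights.items():
        # TF-style names such as "dense/kernel:0" contain '/' and ':'; turn them into safe keys
        key = "w_" + re.sub(r"[^0-9A-Za-z_]", "_", name)
        if key in arrays:
            raise ValueError(f"name collision after sanitizing: {name!r} -> {key!r}")
        arrays[key] = np.asarray(value)
    np.savez(path, **arrays)
    return sorted(arrays)

def load_reference(path):
    """Read every array back into a plain dict."""
    with np.load(path) as data:
        return {k: data[k] for k in data.files}
```

Keeping the input and output next to the weights means the PyTorch side can be checked later without having TF installed in the same environment.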
Afterwards, create the corresponding PyTorch model and load the parameters and buffers manually via:

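```python
import torch

with torch.no_grad():
    # copy_ writes into the existing parameter in place, so it stays registered in the model
    model.layer.parameter.copy_(torch.from_numpy(param))
    ...  # repeat for every other parameter and buffer
```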
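One thing that commonly breaks this copy is tensor layout: a tf.keras `Dense` kernel is stored as `(in_features, out_features)` while `nn.Linear.weight` is `(out_features, in_features)`, and a `Conv2D` kernel is `(kH, kW, in_ch, out_ch)` while `nn.Conv2d.weight` is `(out_ch, in_ch, kH, kW)`. A minimal NumPy sketch of the conversions (function names are hypothetical), with a check that both dense conventions compute the same result:

```python
import numpy as np

def dense_kernel_to_linear_weight(kernel):
    """tf.keras Dense kernel (in, out) -> nn.Linear weight (out, in)."""
    return np.ascontiguousarray(kernel.T)

def conv2d_kernel_to_torch(kernel):
    """tf.keras Conv2D kernel (kH, kW, in, out) -> nn.Conv2d weight (out, in, kH, kW)."""
    return np.ascontiguousarray(kernel.transpose(3, 2, 0, 1))

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 4)).astype(np.float32)
kernel = rng.standard_normal((4, 3)).astype(np.float32)  # TF layout (in, out)
bias = rng.standard_normal(3).astype(np.float32)

y_tf = x @ kernel + bias                        # what Dense computes
weight = dense_kernel_to_linear_weight(kernel)  # PyTorch layout (out, in)
y_torch = x @ weight.T + bias                   # what nn.Linear computes
print(np.abs(y_tf - y_torch).max())
```

Two related pitfalls worth checking: TF defaults to channels-last (NHWC) activations, so a `Flatten` followed by `Dense` orders features differently than `torch.flatten` on NCHW tensors, and Keras `BatchNormalization` defaults to `epsilon=1e-3` versus `eps=1e-5` in PyTorch.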

Perform the forward pass next and compare the outputs of the TF model and your current PyTorch implementation.
If they show a large difference, you could try to narrow down which layer is creating this error.
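A minimal sketch of the narrowing-down step, assuming the intermediate activations of both models have already been collected as NumPy arrays in forward order (e.g. via `module.register_forward_hook` in PyTorch and, for a functional Keras model, a sub-model whose outputs are the layers' outputs). The helper name and the tolerances are assumptions; conv activations from TF are usually NHWC and need `.transpose(0, 3, 1, 2)` before comparing.

```python
import numpy as np

def first_divergence(tf_acts, torch_acts, rtol=1e-4, atol=1e-5):
    """Return (layer_name, max_abs_diff) for the first mismatching layer, or None.

    Both arguments are lists of (layer_name, array) in the same forward order.
    A shape mismatch is reported with an infinite difference.
    """
    for (name, a), (_, b) in zip(tf_acts, torch_acts):
        a, b = np.asarray(a), np.asarray(b)
        if a.shape != b.shape:
            return name, float("inf")
        if not np.allclose(a, b, rtol=rtol, atol=atol):
            return name, float(np.max(np.abs(a - b)))
    return None
```

Since errors propagate forward, the first layer that fails is usually the one to inspect; later layers will typically differ as a consequence.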

Yes, theoretically this is a very good idea, but I will have a hard time mapping the parameter names from TF to PyTorch, since they don’t share the same names and I think I will have to do the mapping manually. My model is huge, so I can imagine this will take some time. But it’s an idea worth trying, thank you! :slight_smile:
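One way to avoid mapping names by hand is to pair tensors by position and let the shapes confirm the pairing. A minimal sketch, assuming both lists are in the same layer order and that every 2D tensor is a Dense kernel and every 4D tensor a Conv2D kernel (embeddings, BatchNorm running statistics from `named_buffers()`, or recurrent layers would need their own rules). The function name is hypothetical.

```python
import numpy as np

# Assumed layout rules (tf.keras -> PyTorch), chosen by rank so square kernels are still transposed.
_CONVERT = {
    2: lambda a: a.T,                      # Dense kernel (in, out) -> nn.Linear (out, in)
    4: lambda a: a.transpose(3, 2, 0, 1),  # Conv2D (kH, kW, in, out) -> (out, in, kH, kW)
}

def map_by_order(tf_weights, torch_shapes):
    """Pair TF arrays with PyTorch names by position and check every shape.

    tf_weights:   list of (tf_name, array), e.g. from the exported file.
    torch_shapes: list of (torch_name, shape), e.g.
                  [(n, tuple(p.shape)) for n, p in model.named_parameters()]
    """
    if len(tf_weights) != len(torch_shapes):
        raise ValueError(f"{len(tf_weights)} TF tensors vs {len(torch_shapes)} PyTorch tensors")
    mapping = {}
    for (tf_name, arr), (pt_name, shape) in zip(tf_weights, torch_shapes):
        arr = np.asarray(arr)
        converted = _CONVERT.get(arr.ndim, lambda a: a)(arr)
        if converted.shape != tuple(shape):
            raise ValueError(f"{tf_name} {arr.shape} does not fit {pt_name} {tuple(shape)}")
        mapping[pt_name] = np.ascontiguousarray(converted)
    return mapping
```

The resulting dict could then be loaded with something like `model.load_state_dict({k: torch.from_numpy(v) for k, v in mapping.items()}, strict=False)`. Note that PyTorch registers parameters in `__init__` order, which may not match the forward order, so printing the pairs once is a cheap sanity check.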