LibTorch Sequential model is not consistent with PyTorch Sequential model

Hi, I have a very simple code example demonstrating that identical implementations of a sequential model in LibTorch and PyTorch end up with different weights and biases. Both versions were run on the same computer with the same, up-to-date version of LibTorch/PyTorch:

Python:

import random
import numpy as np
import torch

random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
print('PyTorch Version:', torch.__version__)

seqModel = torch.nn.Sequential(
    torch.nn.Linear(2, 2),
    torch.nn.Tanh(),
    torch.nn.Linear(2, 2),
    torch.nn.Tanh(),
    torch.nn.Linear(2, 1)
)
print('seqModel:\n', seqModel)
for name, param in seqModel.named_parameters():
    print('name:', name, 'size:', param.data.size(), 'data:\n', param.data)

C++

std::cout << "PyTorch version: "
        << TORCH_VERSION_MAJOR << "."
        << TORCH_VERSION_MINOR << "."
        << TORCH_VERSION_PATCH << std::endl;

        srand(1);
        torch::manual_seed(1);
        at::globalContext().setDeterministicCuDNN(true);
        at::globalContext().setDeterministicAlgorithms(true, false);
        at::globalContext().setBenchmarkCuDNN(false);

        auto net = torch::nn::Sequential(
            torch::nn::Linear(2, 2),
            torch::nn::Tanh(),
            torch::nn::Linear(2, 2),
            torch::nn::Tanh(),
            torch::nn::Linear(2, 1)
        );

        std::cout << net << std::endl;

        for (auto& p : net->named_parameters()) {
            std::cout << p.key() << std::endl;
            std::cout << p.value() << std::endl;
        }

Python output:

PyTorch Version: 1.11.0
seqModel:
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Tanh()
  (2): Linear(in_features=2, out_features=2, bias=True)
  (3): Tanh()
  (4): Linear(in_features=2, out_features=1, bias=True)
)
name: 0.weight size: torch.Size([2, 2]) data:
tensor([[ 0.3643, -0.3121],
        [-0.1371,  0.3319]])
name: 0.bias size: torch.Size([2]) data:
tensor([-0.6657,  0.4241])
name: 2.weight size: torch.Size([2, 2]) data:
tensor([[-0.1455,  0.3597],
        [ 0.0983, -0.0866]])
name: 2.bias size: torch.Size([2]) data:
tensor([0.1961, 0.0349])
name: 4.weight size: torch.Size([1, 2]) data:
tensor([[ 0.2583, -0.2756]])
name: 4.bias size: torch.Size([1]) data:
tensor([-0.0516])

C++ output:

PyTorch version: 1.11.0
torch::nn::Sequential(
  (0): torch::nn::Linear(in_features=2, out_features=2, bias=true)
  (1): torch::nn::Tanh()
  (2): torch::nn::Linear(in_features=2, out_features=2, bias=true)
  (3): torch::nn::Tanh()
  (4): torch::nn::Linear(in_features=2, out_features=1, bias=true)
)
0.weight
-0.0866 0.1961
0.0349 0.2583
[ CPUFloatType{2,2} ]
0.bias
-0.2756
-0.0516
[ CPUFloatType{2} ]
2.weight
0.3319 -0.6657
0.4241 -0.1455
[ CPUFloatType{2,2} ]
2.bias
0.3597
0.0983
[ CPUFloatType{2} ]
4.weight
0.3643 -0.3121
[ CPUFloatType{1,2} ]
4.bias
-0.1371
[ CPUFloatType{1} ]

Does anybody know what in the world is going on? It’s almost like the random number generation is offset: you can see some of the same numbers appear in both outputs, just in completely different spots. Is this how it’s supposed to work? Apologies if it’s something simple I missed; I’ve spent days trying to figure this out.

Thank you!

I think this is generally correct. If you seed the code and print a simple torch.randn tensor before initializing the model, you will see that the Python API and libtorch return the same values, which points to the same RNG logic in these calls.
However, I don’t think there is any guarantee that the libtorch C++ module API and the Python API make these calls in the same order, which seems to create the offset.
If you need to use the same parameters, I would recommend storing the model or its state_dict.
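
For instance, a quick sanity check along these lines (a minimal sketch; the seed value is arbitrary) should print the same numbers as torch.manual_seed(1); print(torch.randn(2, 2)) in Python:

#include <torch/torch.h>
#include <iostream>

int main() {
    // Seed, then sample *before* constructing any modules.
    // Python equivalent: torch.manual_seed(1); print(torch.randn(2, 2))
    torch::manual_seed(1);
    std::cout << torch::randn({2, 2}) << std::endl;
    return 0;
}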

Hi, thank you so much for your suggestion and help. This works if the model is the same on every run of the program, but the input and output sizes of my model can change depending on the context in which it’s used, so I can’t load the weights and biases from a file.

I found a way to define a Sequential with the same weights and biases as in Python without loading them from a file. Since something in the Sequential construction itself consumes randomness, you can define the layers before you create the Sequential model:

torch::nn::Linear inputLayer  = torch::nn::Linear(2, 2);
torch::nn::Linear middleLayer = torch::nn::Linear(2, 2);
torch::nn::Linear outputLayer = torch::nn::Linear(2, 1);

And then construct the sequential model like this:

torch::nn::Sequential critic({
        {"InputLayer" , inputLayer        },
        {"Tanh1"      , torch::nn::Tanh() },
        {"MiddleLayer", middleLayer       },
        {"Tanh2"      , torch::nn::Tanh() },
        {"OutputLayer", outputLayer       }
    });

This model will have the exact same weights and biases as the model in Python.
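
To double-check, you can run a fixed input through the C++ model and compare the printout against the same forward pass in Python. A minimal sketch (the input values here are arbitrary):

// Arbitrary fixed test input; feed the same values to the Python model.
auto input = torch::tensor({{0.5, -0.5}}, torch::kFloat);
auto output = critic->forward(input);
std::cout << output << std::endl;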

However, after some unit testing I’ve recently noticed that the outputs produced by this model, which has weights and biases identical to the Python model’s, differ from the Python outputs by no more than 0.00001%. But there is still a difference, meaning I can’t use LibTorch to exactly recreate a Python sequential model, especially in the case of a reinforcement learning model, where outputs directly affect the environment and therefore future inputs to the same model.
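
For context, the comparison in those unit tests looks roughly like this (a minimal sketch; the expected value is a placeholder that would be copied from the Python run, and the tolerances are ones I picked):

// Placeholder expected value, copied from the Python model's output.
auto expected = torch::tensor({{0.1234}}, torch::kFloat);
// Exact equality fails, but the values agree within floating-point tolerance.
bool matches = torch::allclose(critic->forward(input), expected,
                               /*rtol=*/1e-5, /*atol=*/1e-7);
std::cout << "match within tolerance: " << std::boolalpha << matches << std::endl;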

Does this mean it’s impossible to create a sequential model with dynamic input/output sizes that is consistent between LibTorch and PyTorch?

Thank you!