How to copy network parameters in the libtorch C++ API

Hi,
My question is how to copy the values of trainable parameters from one network to another using the libtorch C++ API.
More precisely:
I have a custom Network class derived from torch::nn::Module and two instances of this class named n1 and n2. I want to copy the trainable parameters from n2 to n1.
In PyTorch this can be achieved with n1.load_state_dict(n2.state_dict()), but the network class has no such methods in the C++ API. (It seems to be related to the clone method, but I could not figure out how to use it.)

Hi, as far as I understand, for now you have to implement a load_state_dict function in C++ manually. It will look something like this:

// Disable autograd while copying so in-place copy_ into parameters is allowed;
// the guard restores the previous grad mode when it goes out of scope.
torch::NoGradGuard no_grad;
auto new_params = ReadStateDictFromFile(params_path);  // implement this
auto params = model->named_parameters(true /*recurse*/);
auto buffers = model->named_buffers(true /*recurse*/);
for (auto& val : new_params) {
    auto name = val.key();
    auto* t = params.find(name);  // nullptr if no parameter has this name
    if (t != nullptr) {
        t->copy_(val.value());
    } else {
        t = buffers.find(name);  // fall back to buffers (e.g. BatchNorm running stats)
        if (t != nullptr) {
            t->copy_(val.value());
        }
    }
}

Thanks a lot, Kirill. I managed to implement a similar solution, but thanks for posting; this code is neater than mine.

Are you going to add something like load_state_dict in the next version?

I’m not a PyTorch developer, but you can look at my implementation here

Thanks, I also implemented a function for myself, though it would be nicer if libtorch had a built-in function.

This is an old thread, but it was linked recently:

These days there is also save/load functionality in C++, but it is based on the JIT file format. So you can save a traced model (trace the model, not a function) and load the weights back into your PyTorch C++ API-built model.

Best regards

Thomas

I have a model in HDF5 format that was trained with Keras in Python. Now I want to load this model in C++ and use it, but I do not know whether libtorch can load it. How should this be done?