How to copy network parameters in the libtorch C++ API


(P Sz) #1

Hi,
My question is how to copy the values of trainable parameters from one network to another
using the libtorch C++ API.
More precisely:
I have a custom Network class derived from torch::nn::Module and two instances of this class named n1 and n2. I want to copy the trainable parameters from n2 to n1.
In PyTorch this can be achieved with n1.load_state_dict(n2.state_dict()), but the network class has no such methods in the C++ API. (It seems to be related to the clone method, but I could not figure out how to use it.)


(Kirill) #2

Hi, as far as I understand, for now you have to implement a load_state_dict function for C++ manually. It will look something like this:

    {
        // Disable autograd for this scope so in-place copies into parameters are allowed;
        // the guard restores the previous grad mode when the scope ends.
        torch::NoGradGuard no_grad;
        auto new_params = ReadStateDictFromFile(params_path);  // implement this
        auto params = model->named_parameters(true /*recurse*/);
        auto buffers = model->named_buffers(true /*recurse*/);
        for (auto& val : new_params) {
            const auto& name = val.key();
            auto* t = params.find(name);
            if (t == nullptr) {
                t = buffers.find(name);  // not a parameter, so try the buffers
            }
            if (t != nullptr) {
                t->copy_(val.value());  // names found in neither dict are skipped
            }
        }
    }

(P Sz) #3

Thanks a lot, Kirill. I had managed to implement a similar solution myself, but thanks for posting this; your code is neater than mine.


(Afshin Oroojlooy) #4

Are you going to add something like load_state_dict in the next version?


(Kirill) #5

I’m not a PyTorch developer, but you can look at my implementation here


(Afshin Oroojlooy) #6

Thanks, I also implemented a function for myself, though it would be nicer if libtorch had a built-in function.