How to copy network parameters in the LibTorch C++ API

Hi, as far as I understand, for now you have to implement the load_state_dict function for C++ manually. It will look like this:

 torch::autograd::GradMode::set_enabled(false);  // make parameters copying possible
 auto new_params = ReadStateDictFromFile(params_path); // implement this
 auto params = model->named_parameters(true /*recurse*/);
 auto buffers = model->named_buffers(true /*recurse*/);
 for (auto& val : new_params) {
     auto name = val.key();
     auto* t = params.find(name);
     if (t != nullptr) {
          t->copy_(val.value());
      } else {
          t = buffers.find(name);
          if (t != nullptr) {
            t->copy_(val.value());
          }
      }
}
torch::autograd::GradMode::set_enabled(true);
4 Likes