How would I do load_state_dict in C++?

Hello!
I would like to clone my network every x iterations. In Python this is done with the following code:
target_net.load_state_dict(policy_net.state_dict())
However, these functions do not seem to exist in C++.

Code from:
Reinforcement Learning (DQN) Tutorial

So what I am looking for is the C++ equivalent of target_net.load_state_dict(policy_net.state_dict()).

I think the intended approach is to compile your model with PyTorch's JIT and then load it in C++.

Alternatively, you can always reimplement your architecture in C++ using the torch::nn namespace and then load a text-serialized version of the weights.

Maybe I should have been clearer: both models are in C++. I do not have a model in Python. That code snippet is from the Reinforcement Learning (DQN) Tutorial, and I would like to do the equivalent in C++.

The current implementation of load_state_dict is in Python: it essentially walks the state-dict dictionary and copies each tensor into the model parameter (or buffer) with the same name.

So I guess you'll need to do the same in C++.

In most cases the following code works, though it is not as comprehensive as PyTorch's load_state_dict in Python. Beforehand, you need to save the state_dict from Python as a plain dictionary (Dict[str, Tensor]).

```cpp
#include <torch/torch.h>
#include <torch/script.h>  // for torch::pickle_load

#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Model class is inherited from public torch::nn::Module and declares
// get_the_bytes() and load_parameters() as member functions.

// Read the whole file into a byte buffer.
std::vector<char> Model::get_the_bytes(std::string filename) {
  std::ifstream input(filename, std::ios::binary);
  std::vector<char> bytes(
      (std::istreambuf_iterator<char>(input)),
      (std::istreambuf_iterator<char>()));
  return bytes;  // the ifstream closes itself when it goes out of scope
}

void Model::load_parameters(std::string pt_pth) {
  std::vector<char> f = this->get_the_bytes(pt_pth);
  // The file must contain a plain Dict[str, Tensor] saved from Python.
  c10::Dict<c10::IValue, c10::IValue> weights = torch::pickle_load(f).toGenericDict();

  // The returned tensors share storage with the model, so copying into them updates the model.
  const torch::OrderedDict<std::string, at::Tensor>& model_params = this->named_parameters();

  torch::NoGradGuard no_grad;  // in-place copies into parameters must not be tracked
  for (auto const& w : weights) {
    std::string name = w.key().toStringRef();
    at::Tensor param = w.value().toTensor();

    // find() returns nullptr when no parameter has this name.
    if (const at::Tensor* target = model_params.find(name)) {
      target->copy_(param);
    } else {
      std::cout << name << " does not exist among model parameters." << std::endl;
    }
  }
}
```

Also, to load buffer tensors (e.g. the batch norm running mean and variance), repeat the same steps with named_buffers().