Accessing weights of a scripted module

I’m trying to access the weights of a scripted module. I load the module as follows:

torch::jit::script::Module model = torch::jit::load(argv[1]);
Here, argv[1] is the path to a .pt file.

I can then print the parameters by iterating over them:
for (const auto& params : model.parameters()) {
  std::cout << params << std::endl;
}

But how can I then modify this params tensor? Say I want to add one to each value in the tensor. I can't modify the params variable because it is const, but if I remove the const, I get an error saying that this form of range-based for loop doesn't work without const.
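The const error most likely comes from how `model.parameters()` is iterated. Its iterator appears to hand out each tensor by value (a temporary), and a non-const `auto&` cannot bind to a temporary. Because a tensor is a lightweight handle onto shared storage, taking it by value with plain `auto` is cheap, and in-place operations on that copy still change the module's parameter. The sketch below reproduces the situation in plain C++ so it compiles without libtorch. `FakeTensor`, `FakeParameterList` and `add_to_all` are hypothetical stand-ins, not the libtorch API.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for at::Tensor: a handle that shares its storage,
// so copying the handle does not copy the data.
struct FakeTensor {
  std::shared_ptr<std::vector<float>> data;

  // In-place add: writes through to the shared storage.
  void add_(float v) {
    for (float& x : *data) x += v;
  }
};

// Hypothetical stand-in for the range returned by model.parameters():
// its iterator yields each element BY VALUE, i.e. as a temporary.
class FakeParameterList {
 public:
  explicit FakeParameterList(std::vector<FakeTensor> params)
      : params_(std::move(params)) {}

  class iterator {
   public:
    iterator(const std::vector<FakeTensor>* v, std::size_t i) : v_(v), i_(i) {}
    FakeTensor operator*() const { return (*v_)[i_]; }  // a copy, not a reference
    iterator& operator++() { ++i_; return *this; }
    bool operator!=(const iterator& other) const { return i_ != other.i_; }

   private:
    const std::vector<FakeTensor>* v_;
    std::size_t i_;
  };

  iterator begin() const { return iterator(&params_, 0); }
  iterator end() const { return iterator(&params_, params_.size()); }

 private:
  std::vector<FakeTensor> params_;
};

// Adds `v` to every element of every parameter.
void add_to_all(const FakeParameterList& params, float v) {
  // `for (auto& p : params)` would not compile: a non-const lvalue
  // reference cannot bind to the temporary returned by operator*.
  // Plain `auto` copies the handle, which still points at the same storage.
  for (auto p : params) {
    p.add_(v);  // modifies the original parameter's data
  }
}
```

With libtorch itself, the analogous loop would be `for (auto p : model.parameters()) p.add_(1);`, run under a `torch::NoGradGuard` so autograd does not record the in-place update.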

You could use the .copy_ operation as in this example:

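torch::NoGradGuard no_grad;  // don't record the in-place writes in autograd
for (auto p : model.parameters()) {  // by value: p shares storage with the parameter
  // Fill p with cos(0), cos(1), ... laid out in p's shape, dtype and device.
  auto values = torch::arange(0, p.numel(), p.options()).view(p.sizes());
  p.copy_(torch::cos(values));
  // To add one to each value instead: p.add_(1);
}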