jit::script::Module parameters are not updating when training

I’m trying to train a torch::jit::script::Module, but the loss is not decreasing. Here is a minimal example.

#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>
#include <vector>

// custom loader code
#include "nets/nets.h"
#include "util/runfiles.h"

int main(int argc, char** argv) {
  std::cout << "Nets example" << std::endl;

  // Custom code that loads the module on CUDA
  auto runfiles = MakeRunfiles(argv[0]);
  torch::jit::script::Module script_module = LoadSegnetBackbone(*runfiles);
  script_module.train();
  std::cout << "Loaded script module" << std::endl;

  // Pull parameters out of the script module so we can push them into the
  // optimizer.
  std::vector<at::Tensor> parameters;
  for (const auto& parameter : script_module.get_parameters()) {
    parameters.push_back(parameter.value().toTensor());
  }
  torch::optim::SGD optimizer(std::move(parameters), /*lr=*/0.01);

  constexpr int kBatchSize = 1;
  for (int epoch = 1; epoch <= 1000; ++epoch) {
    optimizer.zero_grad();

    // The input is a (kBatchSize,3,300,300) tensor filled with ones
    at::Tensor input = torch::ones({kBatchSize, /*channels (rgb) =*/3,
                                    /*height=*/300, /*width=*/300})
                           .to(at::kFloat)
                           .to(at::kCUDA);

    // Push the input through the script module
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(input);
    at::Tensor script_module_forward = script_module.forward(inputs).toTensor();
    // The result is an output tensor of size (kBatchSize, 32, 300, 300)

    // ground truth is a (kBatchSize, 300, 300) tensor filled with ones
    at::Tensor ground_truth =
        torch::ones({kBatchSize, /*height=*/300, /*width=*/300})
            .to(at::kLong)
            .to(at::kCUDA);

    at::Tensor loss = torch::nll_loss2d(
        torch::log_softmax(script_module_forward, /*dim=*/1), ground_truth);
    loss.backward();
    optimizer.step();

    if (epoch % 50 == 0) {
      std::cout << "Loss was " << loss.item<float>() << std::endl;
    }
  }
}

And the output is

Nets example
Loaded script module
Loss was 3.44751
Loss was 3.44751
Loss was 3.44751
Loss was 3.44751
Loss was 3.44751
Loss was 3.44751
Loss was 3.44751
Loss was 3.44751
...

Am I doing something wrong here?


I’m not sure about the C++ semantics, but are you sure that parameters.push_back(parameter.value().toTensor()); is actually storing a reference to the parameter, rather than a clone of it?
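One way to check would be to compare data pointers, placed right after your push_back loop and before the vector is moved into the optimizer. A sketch reusing the variables from your snippet (if at::Tensor behaves like a shared, reference-counted handle, push_back would copy the handle rather than the underlying data; this makes it observable either way):

// Sketch: print whether each tensor collected for the optimizer
// aliases the module's own parameter storage.
size_t i = 0;
for (const auto& parameter : script_module.get_parameters()) {
  const bool shared =
      parameter.value().toTensor().data_ptr() == parameters[i].data_ptr();
  std::cout << "parameter " << i << ": "
            << (shared ? "same storage" : "different storage") << std::endl;
  ++i;
}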


Oh, that’s a good point. Do you know of any resource that explains the correct way to pass the parameters to the optimizer?

  torch::optim::SGD optimizer(
      model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));

(from Peter Goldsborough’s example repo).

Thanks @tom, but I don’t think the torch::jit::script::Module has a .parameters() member function.
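
For now I am pulling them out by hand. Here is a sketch of a parameters()-style helper (the get_modules() recursion into submodules is an assumption about this libtorch version’s API; the accessor names have changed between releases):

// Sketch: collect the parameters of a script module and, assuming a
// get_modules() accessor exists in this version, of its submodules.
std::vector<at::Tensor> GatherParameters(
    const torch::jit::script::Module& module) {
  std::vector<at::Tensor> result;
  // Direct parameters of this module.
  for (const auto& parameter : module.get_parameters()) {
    result.push_back(parameter.value().toTensor());
  }
  // Recurse into submodules.
  for (const auto& child : module.get_modules()) {
    auto child_parameters = GatherParameters(child);
    result.insert(result.end(), child_parameters.begin(),
                  child_parameters.end());
  }
  return result;
}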

Right, sorry. Does changing at::Tensor to torch::Tensor, or even to autograd::Variable, help?

Thanks for the suggestion. Changing at::Tensor to torch::Tensor didn’t help, and I’m not sure where to use autograd::Variable.

In the same place. (Just a guess from looking at jit::script::Module.)

I got a compile error from doing that, since the optimizer’s constructor takes a vector of tensors:

bazel-out/k8-fastbuild/bin/third_party/libtorch/_virtual_includes/libtorch_csrc/torch/optim/sgd.h: In instantiation of 'torch::optim::SGD::SGD(ParameterContainer&&, const torch::optim::SGDOptions&) [with ParameterContainer = std::vector<torch::autograd::Variable>]':
nets/nets_example_2.cpp:25:65:   required from here
bazel-out/k8-fastbuild/bin/third_party/libtorch/_virtual_includes/libtorch_csrc/torch/optim/sgd.h:36:24: error: no matching function for call to 'torch::optim::Optimizer::Optimizer(std::vector<torch::autograd::Variable>)'
         options(options) {}
                        ^
In file included from bazel-out/k8-fastbuild/bin/third_party/libtorch/_virtual_includes/libtorch_csrc/torch/optim/adagrad.h:4,
                 from bazel-out/k8-fastbuild/bin/third_party/libtorch/_virtual_includes/libtorch_csrc/torch/optim.h:3,
                 from bazel-out/k8-fastbuild/bin/third_party/libtorch/_virtual_includes/libtorch_csrc/torch/all.h:7,
                 from bazel-out/k8-fastbuild/bin/third_party/libtorch/_virtual_includes/libtorch_csrc/torch/torch.h:3,
                 from nets/nets_example_2.cpp:2:
bazel-out/k8-fastbuild/bin/third_party/libtorch/_virtual_includes/libtorch_csrc/torch/optim/optimizer.h:37:12: note: candidate: 'torch::optim::detail::OptimizerBase::OptimizerBase(std::vector<at::Tensor>)'
   explicit OptimizerBase(std::vector<Tensor> parameters);

Hi, did you solve it? I have the same problem.

I have not solved it yet. There is an open issue for it on GitHub, https://github.com/pytorch/pytorch/issues/28478; maybe you can give it a thumbs up to draw more attention to it.

I had an idea: manually write the parameters back into the module after the gradient step. I haven’t tried it yet.
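
Something like this (untested sketch; it assumes the optimizer was constructed without std::move, so the parameters vector is still valid, and that get_parameters() iterates in the same order it did when the vector was filled):

optimizer.step();
{
  // Untested sketch: copy the stepped tensors back into the script
  // module's parameter slots after each optimizer step.
  torch::NoGradGuard no_grad;
  size_t i = 0;
  for (const auto& parameter : script_module.get_parameters()) {
    parameter.value().toTensor().copy_(parameters[i++]);
  }
}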

Thanks, I have not tried it yet. There is some code on GitHub that loads a script model into a torch::nn::Module, https://github.com/kerry-Cho/transfer-learning-Libtorch; maybe it will help you in some way.

Interesting. If I understand correctly, one way of training the JIT model is to define the identical model in a C++ header and then load the weights with torch::load instead of torch::jit::load, roughly as in the sketch below.
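
(A minimal sketch; NetImpl, its single conv layer, and the file name weights.pt are placeholders, and it assumes the weights were saved with torch::save from a matching C++ definition.)

#include <torch/torch.h>

// Placeholder module standing in for the real architecture; the C++
// definition must mirror the original model layer for layer for the
// weights to load.
struct NetImpl : torch::nn::Module {
  NetImpl()
      : conv(torch::nn::Conv2dOptions(/*in=*/3, /*out=*/32, /*kernel=*/3)
                 .padding(1)) {
    register_module("conv", conv);
  }
  torch::Tensor forward(torch::Tensor x) { return conv->forward(x); }
  torch::nn::Conv2d conv;
};
TORCH_MODULE(Net);

int main() {
  Net net;
  torch::load(net, "weights.pt");  // torch::load, not torch::jit::load
  // net->parameters() exists here, so the optimizer sees live tensors.
  torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);
}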

However, I am also struggling to train the JIT model directly. The example that @markl showed pulls the parameters out of the JIT model, passes them to the optimizer, and then updates them with the optimizer. I am wondering whether this can be made to work as well, and whether this is the direction PyTorch is taking.

If this could be done, I think it would be great: any other software could simply call into libtorch and perform both training and inference, making it stand out compared to TensorFlow.

Just my 2 cents; apologies if I have made any wrong statements above.

rgds,
CL


Same problem here.

I also found a similar discussion:

Any update?