Can I Create a Model at Runtime?

Currently, the C++ examples (such as mnist, custom-dataset, and dcgan) define their models in C++ source code, which is then compiled to run training and validation. This means the model is "fixed" once the source is compiled, and any change to the network structure requires recompiling the code.

Are there any examples of a compiled executable that accepts a network model (such as the MNIST CNN) described in a file (JSON, prototxt, etc.) and trains the model from that external file? This would allow users to modify the network structure in the file and retrain the model without recompiling the code.

A close example is Caffe (not Caffe2), which defines models in prototxt files.

Thanks.

Regards,
CL

One way to achieve this is by defining the model in TorchScript and then training it in C++:

  1. Make sure your PyTorch and libtorch installations are the same version / commit. One way to ensure this is to install the nightly builds:

    1. PyTorch: pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cu90/torch_nightly.html (change “cu90” to your CUDA version)
    2. libtorch: wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip &&
      unzip libtorch-shared-with-deps-latest.zip
  2. Run the following Python script to export a model:

    import torch
    
    class Model(torch.jit.ScriptModule):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = torch.nn.Linear(2, 2)
    
        @torch.jit.script_method
        def forward(self, x):
            a = 0
            while a < 4: # While loop example
                print(a)
                a += 1
            x = self.linear(x)
            return x
    
    model = Model()
    model.save('model1.pt')
  3. Create a C++ file called example-app.cpp that contains the following:

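    #include <torch/torch.h>
    #include <torch/script.h>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <vector>
    
    int main(int argc, char* argv[]) {
      // Optionally take the model path from the command line, so a differently
      // exported model can be trained without recompiling; defaults to ../model1.pt
      std::string module_path = argc > 1 ? argv[1] : "../model1.pt";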
      // NOTE: depending on the libtorch version, you might need to change this line to `torch::jit::script::Module module = torch::jit::load(module_path);`
      // See https://discuss.pytorch.org/t/how-to-load-previously-saved-model-in-cpp-frontend/52336/6 for details.
      std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(module_path);
    
      // TorchScript module's forward method only accepts std::vector<torch::jit::IValue> as input
      std::vector<torch::jit::IValue> inputs;
      inputs.push_back(torch::randn({2, 2}).set_requires_grad(true));
    
      // Before backward() has run, grad() returns an undefined tensor
      std::cout << "gradient of self.linear.weight before backward: " << "\n";
      std::cout << module->get_module("linear")->get_parameter("weight").grad() << "\n";
    
      torch::Tensor prediction = module->forward(inputs).toTensor();
      prediction.sum().backward();
    
      // Verify that backward works: the gradient should now be populated
      std::cout << "gradient of self.linear.weight after backward: " << "\n";
      std::cout << module->get_module("linear")->get_parameter("weight").grad() << "\n";
    }
  4. Create CMakeLists.txt and put the following in it:

    cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
    project(example-app)
    
    find_package(Torch REQUIRED)
    
    add_executable(example-app example-app.cpp)
    target_link_libraries(example-app "${TORCH_LIBRARIES}")
    set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
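  5. mkdir build && cd build
  6. cmake -DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch .. # /absolute/path/to/libtorch should be the absolute path to the unzipped LibTorch distribution
  7. make
  8. Run ./example-app to see the output (the program loads ../model1.pt, so model1.pt should sit in the project directory, one level above build)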

Later on, when you want to modify the network structure of your model, you can change the model definition in Python, export the model again (overwriting model1.pt), and then re-train it with the same compiled C++ executable. This way you don't need to recompile the C++ code.


Hi, thanks for the great explanation; this will be one of the solutions.

Something perhaps more demanding: is it possible to create the model.pt from a simple text editor, without using Python? It would be good to have everything in C++. The text file containing the model architecture could be plain text, JSON, or Google's prototxt format.

Thanks again.

Regards,
Chin Luh