How can I train in C++ using a Pytorch torchscript model

I trained a model in PyTorch and saved it in TorchScript format using torch.jit.save.
Now I want to retrain this model, so I'd like to know whether a TorchScript model can be used for training.

Thanks


Yes, you can train your model in libtorch and @krshrimali has published some blog posts with examples in this post.


Thanks for your reply. I read @krshrimali's blog, but I still have a few questions about how to train a TorchScript model in C++.
I want to fine-tune an already trained model. I generated the TorchScript model in PyTorch, loaded it through the C++ API with torch::jit::load, and now want to retrain it.
In my code:
torch::jit::script::Module m_model = torch::jit::load(m_modulePath);
torch::optim::SGD optimizer(m_model.parameters(), SGDoptions);

When I construct the optimizer, the compiler reports that the first argument is incorrect.


Thanks @ptrblck for the mention. :slight_smile:

Just to understand the error better, @ChenyijunAaron, it would help if you could share the exact error message here.

Thanks for your reply. :slightly_smiling_face:

I want to fine-tune a model that was generated with PyTorch, so I load the .pt model in libtorch.
When I construct the torch::optim::SGD optimizer, the compiler reports that the first argument I pass does not match. See the following code:
torch::jit::script::Module m_model = torch::jit::load(m_modulePath);
torch::optim::SGDOptions SGDoptions(m_train_lr);
torch::optim::SGD optimizer(m_model.parameters(), SGDoptions);
(m_model was generated in PyTorch with torch.jit.trace and torch.jit.save.)

But the following code (taken from your blog) compiles without errors:
torch::nn::Linear linear_layer{ 512,2 };
torch::optim::Adam optimizer(linear_layer->parameters(), torch::optim::AdamOptions(1e-3));

Is it because a TorchScript model can't backpropagate in libtorch? If I want to train the TorchScript model, what should I do?

Thank you so much. :slightly_smiling_face:

Hi ChenyijunAaron,

Glad to discuss training or fine-tuning a Python-saved .pt module in C++ with libtorch here. The error is caused by a type mismatch between torch::jit::parameter_list (what the scripted module's parameters() returns) and std::vector<at::Tensor> (what the optimizer constructor expects).

I have also run into this problem and am still debugging it.

How to fix it? For now, I have switched to training the model in pure C++ and given up on importing the Python model. I will try the import approach again when I have time. If you solve it, it would be great if you could share the solution here.

Enjoy the coding!