How can I train a keypointrcnn_resnet50_fpn in Python and load it in C++? Or is there any way to include it directly in C++ and train it there?
You could try scripting the model with torch.jit.script in Python and loading it in libtorch afterwards, as explained here.
I have already gone through the documentation. It works fine for the resnet18 model, but when I save fasterrcnn_resnet50_fpn with the same instructions and try to load it in C++, I get this error:
terminate called after throwing an instance of 'torch::jit::ErrorReport'
  what():
Unknown type name 'NoneType':
Serialized File "code/__torch__/torchvision/models/detection/transform.py", line 11
  image_std : List[float]
  size_divisible : int
  fixed_size : NoneType
               ~~~~~~~~ <--- HERE
  def forward(self: __torch__.torchvision.models.detection.transform.GeneralizedRCNNTransform,
    images: List[Tensor],
Aborted (core dumped)
I’m going to share the Python, C++ and CMake files here for more information:
Python:
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn()
model.eval()  # switch to inference mode before exporting

# Detection models contain data-dependent control flow, so script them with
# torch.jit.script (which takes no example input) instead of tracing them.
scripted_module = torch.jit.script(model)
scripted_module.save("rcnn_model.pt")
C++:
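#include <iostream>
#include <torch/torch.h>
#include <torch/script.h>  // torch::jit::load, torch::jit::script::Module

int main()
{
    torch::jit::script::Module model;
    try
    {
        // Deserialize the TorchScript module saved from Python.
        model = torch::jit::load("/home/bluesky/Documents/Codes/AI/Keypoint_RCNN/rcnn_model.pt");
    }
    catch (const c10::Error& e)
    {
        std::cerr << "Error loading the model: " << e.msg() << std::endl;
        return -1;
    }
    catch (const std::exception& e)
    {
        // torch::jit::ErrorReport (the exception in the log above) derives from
        // std::exception, not c10::Error, so it needs its own handler.
        std::cerr << "Error loading the model: " << e.what() << std::endl;
        return -1;
    }
    std::cout << "Model loaded" << std::endl;
    return 0;
}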
CMake:
cmake_minimum_required(VERSION 3.2)
project(torchtest)
set(CMAKE_PREFIX_PATH /home/bluesky/Downloads/libtorch-cxx11-abi-shared-with-deps-1.8.0+cu112/libtorch)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_executable(${PROJECT_NAME} main.cpp)
target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 14)
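A likely (but unconfirmed) cause of this error: the serialized `fixed_size : NoneType` annotation suggests the file was written by a newer PyTorch/torchvision than the libtorch 1.8.0 build in CMAKE_PREFIX_PATH can parse, so the first thing worth checking is that the PyTorch used for scripting and the libtorch used for loading are the same release. The helper below is a hypothetical sketch (the names `release` and `versions_match` are made up for illustration). It only compares the x.y.z release numbers and ignores local suffixes such as +cu112.

```python
import re
from typing import Tuple

# Matches the first "x.y.z" release number in a string, e.g. "1.8.0" in
# "libtorch-cxx11-abi-shared-with-deps-1.8.0+cu112".
_RELEASE_RE = re.compile(r"(\d+)\.(\d+)\.(\d+)")


def release(version: str) -> Tuple[int, int, int]:
    """Return (major, minor, patch) from a version string or libtorch path.

    Hypothetical helper: suffixes such as "+cu112" or "a0+git..." are ignored.
    """
    match = _RELEASE_RE.search(version)
    if match is None:
        raise ValueError(f"no x.y.z release number found in {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


def versions_match(pytorch_version: str, libtorch_path: str) -> bool:
    """True if the PyTorch used for scripting and the libtorch build are the same release."""
    return release(pytorch_version) == release(libtorch_path)
```

In the Python environment that scripted the model, calling `versions_match(torch.__version__, libtorch_dir)` with the libtorch directory from the CMake file would return False on a mismatch; in that case, either re-script the model with the PyTorch release matching libtorch, or download the libtorch build matching the installed PyTorch.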