Keypoint RCNN in C++

How can I train a keypointrcnn_resnet50_fpn in Python and load it in C++? Or is there any way to include it directly in C++ and train it there?

You could try scripting the model with torch.jit.script in Python and then loading it in libtorch, as explained here.

I already went through the documentation. It works fine for the resnet18 model, but when I save fasterrcnn_resnet50_fpn with the same steps and try to load it in C++, I get this error:


terminate called after throwing an instance of 'torch::jit::ErrorReport'
  what():
Unknown type name 'NoneType':
Serialized File "code/__torch__/torchvision/models/detection/transform.py", line 11
  image_std : List[float]
  size_divisible : int
  fixed_size : NoneType
               ~~~~~~~~ <--- HERE
  def forward(self: __torch__.torchvision.models.detection.transform.GeneralizedRCNNTransform,
    images: List[Tensor],

Aborted (core dumped)

Here are the Python, C++, and CMake files for more context:

Python:

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn()
model.eval()

# torch.jit.script compiles the model from its Python source, so unlike
# torch.jit.trace it needs no example input (its second positional
# argument is not an example input, so none is passed here).
scripted_module = torch.jit.script(model)

scripted_module.save("rcnn_model.pt")

C++:

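#include <iostream>
#include <torch/torch.h>
#include <torch/script.h>

int main()
{
    torch::jit::script::Module model;

    try
    {
        model = torch::jit::load("/home/bluesky/Documents/Codes/AI/Keypoint_RCNN/rcnn_model.pt");
    }
    catch (const c10::Error& e)
    {
        std::cerr << "Error loading the model: " << e.msg() << std::endl;
        return -1;
    }
    catch (const std::exception& e)
    {
        // torch::jit::ErrorReport (the exception in the log above) derives from
        // std::exception rather than c10::Error, so it needs its own handler.
        std::cerr << "Error loading the model: " << e.what() << std::endl;
        return -1;
    }

    std::cout << "Model loaded successfully" << std::endl;
    return 0;
}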

CMake:

cmake_minimum_required(VERSION 3.2)

project(torchtest)

set(CMAKE_PREFIX_PATH /home/bluesky/Downloads/libtorch-cxx11-abi-shared-with-deps-1.8.0+cu112/libtorch)
find_package(Torch REQUIRED)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(${PROJECT_NAME} main.cpp)

target_link_libraries(${PROJECT_NAME} "${TORCH_LIBRARIES}")

set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 14)