Hi! Maybe this isn't the right place for this question. I'm trying to write a pybind11 Torch extension that takes a ScriptModule as input. I have a function like this:
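#include <pybind11/pybind11.h>
#include <torch/script.h>

using namespace pybind11::literals;  // needed for the "model"_a argument literal

// Returns a new module whose type name is the input's name plus "_txp".
torch::jit::script::Module preprocess(const torch::jit::script::Module& mod) {
  torch::jit::script::Module new_mod(mod._ivalue()->name() + "_txp");
  return new_mod;
}

void add_submodule(pybind11::module_& m) {
  m.def("preprocess",
        &preprocess,
        "Preprocesses the TorchScript module",
        "model"_a);
}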
This compiles fine, but when I call the function from Python I get this error:
TypeError: func_name(): incompatible function arguments. The following argument types are supported:
1. (model: torch::jit::Module) -> torch::jit::Module
Invoked with: <torch._C.ScriptModule object at 0x7fc27d657770>
How do I resolve this?
For more context, this works fine:
torch::Tensor preprocess(const torch::Tensor& mod) { return mod; }
PYBIND11_MODULE(MODULE_NAME, m) { m.def("preprocess", &preprocess); }
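A likely explanation for why the Tensor version works but the Module version doesn't: torch::Tensor crosses the boundary through a custom pybind11 type caster that PyTorch defines in its own headers, whereas torch::jit::Module is a class bound inside libtorch_python, so an extension can only find it through pybind11's shared type registry. A hint in the error message points the same way: pybind11 prints the Python type name for types it knows about, and the raw C++ name (here torch::jit::Module rather than torch._C.ScriptModule) for types it can't find. The toy model below is plain Python, not pybind11's real code; the ID format, version number and ABI strings are illustrative assumptions. It shows how two modules compiled with different ABI settings end up looking in separate registries:

```python
"""Toy model of pybind11's shared type registry (not pybind11's real code).

Assumption for illustration: every extension module looks up bound C++ types
in a registry keyed by an "internals ID" built from the pybind11 ABI settings
it was compiled with, so two modules share types only if their IDs match.
"""
from typing import Dict, Optional

# internals ID -> {C++ type name -> Python type name}
_registries: Dict[str, Dict[str, str]] = {}


def internals_id(version: int, compiler_type: str, stdlib: str, build_abi: str) -> str:
    # Simplified version of the string pybind11 builds from its ABI macros.
    return f"__pybind11_internals_v{version}{compiler_type}{stdlib}{build_abi}__"


def register_type(abi_id: str, cpp_name: str, py_name: str) -> None:
    # Conceptually what binding a class does: record it under this module's ID.
    _registries.setdefault(abi_id, {})[cpp_name] = py_name


def lookup_type(abi_id: str, cpp_name: str) -> Optional[str]:
    # None means "not a registered type here", so the argument cannot be converted.
    return _registries.get(abi_id, {}).get(cpp_name)


# Illustrative ABI strings: libtorch_python vs. an extension whose
# auto-detected settings differ (here, only the C++ ABI version).
torch_id = internals_id(4, "_gcc", "_libstdcpp", "_cxxabi1011")
ext_id = internals_id(4, "_gcc", "_libstdcpp", "_cxxabi1013")

register_type(torch_id, "torch::jit::Module", "torch._C.ScriptModule")
print(lookup_type(torch_id, "torch::jit::Module"))  # torch._C.ScriptModule
print(lookup_type(ext_id, "torch::jit::Module"))    # None
```

If this model holds, the extension needs to be compiled with the same pybind11 ABI strings that PyTorch itself was built with.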
Here’s my CMake file:
cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
project(txp-cpp-backend LANGUAGES CXX C)
set(LIB_NAME cpp)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wformat")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=cpp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wformat-security")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMODULE_NAME=lib${LIB_NAME}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -O3 -Wno-reorder")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
find_package(PythonLibs REQUIRED)
find_library(TORCH_PYTHON_LIBRARY torch_python
  PATHS "${TORCH_INSTALL_PREFIX}/lib")
message(STATUS "TORCH_PYTHON_LIBRARY: ${TORCH_PYTHON_LIBRARY}")
add_library(${LIB_NAME} SHARED txp.cpp)
target_compile_features(${LIB_NAME} PRIVATE cxx_std_14)
set_target_properties(${LIB_NAME} PROPERTIES OUTPUT_NAME ${LIB_NAME})
set_target_properties(${LIB_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(${LIB_NAME}
  PROPERTIES LINK_FLAGS "-Wl,-rpath,${CMAKE_INSTALL_RPATH}")
target_compile_options(
  ${LIB_NAME} PUBLIC -Wall -Wno-sign-compare -Wno-unused-function
  -Wno-unknown-pragmas)
target_include_directories(
  ${LIB_NAME}
  PUBLIC ${TORCH_INCLUDE_DIRS}
  PUBLIC ${PYTHON_INCLUDE_DIRS})
target_link_libraries(
  ${LIB_NAME}
  PUBLIC ${PYTHON_LIBRARIES}
  PUBLIC ${TORCH_LIBRARIES}
  PUBLIC ${TORCH_PYTHON_LIBRARY})
And my PyTorch version:
>>> import torch; torch.__version__
'1.11.0.dev20211121'
When I try to add #include <torch/csrc/jit/python/script_init.h>, I get:
fatal error: torch/csrc/generic/Storage.h: No such file or directory
That header is indeed missing from the Torch install path.
What does the Python code look like? I suspect that torch.jit.ScriptModule is a Python wrapper and the actual pybind11-bound object is stored in its _c attribute.
Something like this:
script_model = torch.jit.load(cfg.model_file, map_location="cpu")
preprocessed_model = cpplib.preprocess(script_model._c)
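To let the extension accept either form from Python, a small wrapper can unwrap _c before calling into C++. This is a sketch: unwrap_script_module is a hypothetical helper, and _c is a private attribute of the Python ScriptModule wrapper that could change between releases.

```python
from typing import Any


def unwrap_script_module(module: Any) -> Any:
    """Return the pybind11-bound torch._C.ScriptModule behind a scripted model.

    torch.jit.load() and torch.jit.script() return a Python wrapper whose
    C++ object is stored in the private `_c` attribute; any object without
    `_c` (e.g. an already-unwrapped torch._C.ScriptModule) is returned as-is.
    """
    return getattr(module, "_c", module)
```

Usage would be cpplib.preprocess(unwrap_script_module(script_model)). The C++ function returns a bare torch._C.ScriptModule; torch.jit.load rewraps such objects with the private torch.jit._recursive.wrap_cpp_module, which may work for the result as well.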
Yep, I'm passing the ._c attribute rather than the original module.
OK, it turns out I was missing some important -DPYBIND11_ compile definitions. I've created a minimal example repo for posterity: GitHub - codekansas/torchscript-example: Example CMake project for TorchScript
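For reference, the pybind11 ABI definitions in question are most likely PYBIND11_COMPILER_TYPE, PYBIND11_STDLIB and PYBIND11_BUILD_ABI; torch.utils.cpp_extension adds them automatically, reading the values from torch._C._PYBIND11_COMPILER_TYPE, torch._C._PYBIND11_STDLIB and torch._C._PYBIND11_BUILD_ABI. Assuming your PyTorch build exposes those attributes (worth checking with hasattr on your version), a sketch that turns them into compiler flags for a CMake build might look like this; pybind11_abi_defines and torch_pybind11_abi are hypothetical helper names:

```python
"""Build the pybind11 ABI -D flags that match the installed PyTorch (sketch).

Assumes torch._C exposes _PYBIND11_COMPILER_TYPE, _PYBIND11_STDLIB and
_PYBIND11_BUILD_ABI, the values torch.utils.cpp_extension reads for this purpose.
"""
from typing import Dict, List, Mapping, Optional

ABI_KEYS = ("COMPILER_TYPE", "STDLIB", "BUILD_ABI")


def pybind11_abi_defines(values: Mapping[str, Optional[str]]) -> List[str]:
    """Turn {"COMPILER_TYPE": "_gcc", ...} into -D flags, skipping missing/None values.

    Each value is wrapped in double quotes because pybind11 concatenates these
    macros into a C string literal.
    """
    flags = []
    for key in ABI_KEYS:
        value = values.get(key)
        if value is not None:
            flags.append(f'-DPYBIND11_{key}="{value}"')
    return flags


def torch_pybind11_abi() -> Dict[str, Optional[str]]:
    """Read the ABI strings PyTorch was built with (None if an attribute is absent)."""
    import torch  # imported lazily so pybind11_abi_defines works without torch

    return {key: getattr(torch._C, f"_PYBIND11_{key}", None) for key in ABI_KEYS}
```

Calling pybind11_abi_defines(torch_pybind11_abi()) in the same Python environment you build against gives the flag list. However the flags are passed to CMake, each value has to reach the compiler as a quoted string literal, so check the generated compile command if the error persists.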