PyBind11 doesn't recognize `torch._C.ScriptModule` as `torch::jit::Module`

Hi! Maybe this isn’t the right place for this question. I am trying to write a PyBind11 Torch extension that takes a ScriptModule as input. I have some function like this:

// Returns a newly constructed module whose name is the input module's name
// with "_txp" appended. Only the name of `mod` is read here.
torch::jit::script::Module preprocess(const torch::jit::script::Module& mod) {
  const auto suffixed_name = mod._ivalue()->name() + "_txp";
  return torch::jit::script::Module(suffixed_name);
}

// Registers `preprocess` on the given pybind11 module under the Python name
// "preprocess", with a docstring and a single keyword argument named "model".
void add_submodule(pybind11::module_& m) {
  m.def(
      "preprocess",
      &preprocess,
      "Preprocesses the TorchScript module",
      pybind11::arg("model"));  // spelled-out form of the "model"_a literal
}

This compiles fine, but when I call the function from Python I get this error:

TypeError: func_name(): incompatible function arguments. The following argument types are supported:
    1. (model: torch::jit::Module) -> torch::jit::Module

Invoked with: <torch._C.ScriptModule object at 0x7fc27d657770>

How do I resolve this?

More context, this works fine:

// Tensor variant used for comparison: hands back the tensor it was given.
torch::Tensor preprocess(const torch::Tensor& mod) {
  return mod;
}

// Extension entry point. MODULE_NAME is a preprocessor symbol (the CMake file
// below passes it via -DMODULE_NAME=...). Exposes `preprocess` to Python.
PYBIND11_MODULE(MODULE_NAME, m) {
  m.def("preprocess", &preprocess);
}

Here’s my CMake file:

# Builds txp.cpp as a shared library linked against the Python libraries,
# LibTorch, and torch_python.
# NOTE(review): no -DPYBIND11_* compile definitions are set in the lines shown.
cmake_minimum_required(VERSION 3.12 FATAL_ERROR)

project(txp-cpp-backend LANGUAGES CXX C)

# Target name; also used below to build the MODULE_NAME definition.
set(LIB_NAME cpp)

# Global C++ flags, appended one at a time to CMAKE_CXX_FLAGS.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wformat")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=cpp")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wformat-security")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fstack-protector")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
# Defines MODULE_NAME as lib${LIB_NAME} (libcpp with the value set above).
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMODULE_NAME=lib${LIB_NAME}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -O3 -Wno-reorder")
# Require C++17 with compiler extensions turned off.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# Locate Torch and append the flags it reports to the global C++ flags.
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
find_package(PythonLibs REQUIRED)

# Search for the torch_python library under the Torch install prefix and print
# the path that was found.
find_library(TORCH_PYTHON_LIBRARY torch_python
             PATHS "${TORCH_INSTALL_PREFIX}/lib")
message(STATUS "TORCH_PYTHON_LIBRARY: ${TORCH_PYTHON_LIBRARY}")

# The extension itself: a shared library built from a single source file.
add_library(${LIB_NAME} SHARED txp.cpp)

# NOTE(review): this requests cxx_std_14 although CMAKE_CXX_STANDARD above is 17.
target_compile_features(${LIB_NAME} PRIVATE cxx_std_14)
# Output file base name is the same as the target name.
set_target_properties(${LIB_NAME} PROPERTIES OUTPUT_NAME ${LIB_NAME})
# -fPIC is also appended to CMAKE_CXX_FLAGS above.
set_target_properties(${LIB_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
# Passes an rpath to the linker; CMAKE_INSTALL_RPATH is not set in the lines
# shown, so it presumably comes from the command line or cache -- TODO confirm.
set_target_properties(${LIB_NAME}
                      PROPERTIES LINK_FLAGS "-Wl,-rpath,${CMAKE_INSTALL_RPATH}")
# Per-target warning options; -Wall also appears in the global flags above.
target_compile_options(
  ${LIB_NAME} PUBLIC -Wall -Wno-sign-compare -Wno-unused-function
                     -Wno-unknown-pragmas)
# Header search paths for Torch and Python (PYTHON_INCLUDE_DIRS is presumably
# provided by find_package(PythonLibs)).
target_include_directories(
  ${LIB_NAME}
  PUBLIC ${TORCH_INCLUDE_DIRS}
  PUBLIC ${PYTHON_INCLUDE_DIRS})
# Link against Python, LibTorch, and the torch_python library located above.
target_link_libraries(
  ${LIB_NAME}
  PUBLIC ${PYTHON_LIBRARIES}
  PUBLIC ${TORCH_LIBRARIES}
  PUBLIC ${TORCH_PYTHON_LIBRARY})

And this guy:

>>> import torch; torch.__version__
'1.11.0.dev20211121'

When I try to add `#include <torch/csrc/jit/python/script_init.h>`, I get:

fatal error: torch/csrc/generic/Storage.h: No such file or directory

which is indeed missing from the Torch install path

What does the Python code look like? What I suspect is happening is: `torch.jit.ScriptModule` is a Python wrapper; the actual pybind11-bound type is stored in the attribute `_c` on it.

Something like this:

# Load the TorchScript model from cfg.model_file with map_location="cpu".
script_model = torch.jit.load(cfg.model_file, map_location="cpu")
# Pass the `_c` attribute to the C++ extension, not the ScriptModule wrapper.
preprocessed_model = cpplib.preprocess(script_model._c)

Yep, I am passing the `._c` attribute rather than the original module.

Ok, turns out I was missing some important -DPYBIND11_ directives. I’ve created a minimal example repo here for posterity’s sake: GitHub - codekansas/torchscript-example: Example CMake project for TorchScript