I have the following Python code:
```python
import torch.nn as nn
import torch as pt
import myCppModule

class Model(nn.Module):
    def __init__(self):
        super().__init__()

model = pt.jit.script(Model())
myCppModule.myCppFoo(model)
```
I want to implement a C++ function, `myCppFoo`, that takes `model` as its argument, and then expose it to Python with something like pybind11.
I tried to implement this with:
```cpp
#include <torch/torch.h>
#include <pybind11/pybind11.h>
#include <iostream>

void myCppFoo(const torch::jit::Module& module) {
    std::cerr << module.get_properties().front().name << std::endl;
}

PYBIND11_MODULE(myCppModule, m) {
    m.def("myCppFoo", &myCppFoo);
}
```
However, after compiling the extension and running the Python example, I get this error:
```
Traceback (most recent call last):
  File "/home/user/test/ex.py", line 10, in <module>
    myCppModule.myCppFoo(model)
TypeError: myCppFoo(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch::jit::Module) -> None
Invoked with: RecursiveScriptModule(original_name=Model)
```
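The `Invoked with:` line suggests what pybind11 actually received: the Python-level `RecursiveScriptModule` wrapper, not the C++ `torch::jit::Module` it wraps. A minimal sketch for inspecting this, assuming a recent PyTorch in which the wrapper keeps the C++-backed object in the private attribute `_c` (not a documented API, so it may change between versions):

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()


model = torch.jit.script(Model())

# The object returned by torch.jit.script is a Python wrapper class.
print(type(model))
# The private `_c` attribute (an assumption about PyTorch internals)
# holds the object whose pybind11-bound class is torch._C.ScriptModule.
print(type(model._c))

# Only the wrapped object is an instance of the bound class.
print(isinstance(model, torch._C.ScriptModule))     # False
print(isinstance(model._c, torch._C.ScriptModule))  # True
```

If the extension is built so that it shares PyTorch's pybind11 type registrations (for example via `torch.utils.cpp_extension`, with matching compiler and pybind11 ABI), passing `model._c` instead of `model` may be accepted as a `torch::jit::Module`; this depends on the build setup and is worth verifying.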
Clearly the types are not compatible. How can I modify the C++ function to accept the type provided by torch.jit.script?
The docs suggest serializing the module to a file in Python and deserializing it in C++, but I would like to avoid writing an intermediate file.
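Serialization does not strictly require a file on disk: `torch.jit.save` and `torch.jit.load` also accept file-like objects. A minimal sketch of an in-memory round trip, where the `AddOne` module and the `script_module_to_bytes` helper are hypothetical (the `forward` method exists only so the round trip can be checked):

```python
import io

import torch
import torch.nn as nn


class AddOne(nn.Module):
    """Hypothetical module with a forward method, used only for the check."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


def script_module_to_bytes(module: torch.jit.ScriptModule) -> bytes:
    """Hypothetical helper: serialize a scripted module into memory."""
    buffer = io.BytesIO()
    torch.jit.save(module, buffer)  # writes the archive to the buffer, not a file
    return buffer.getvalue()


model = torch.jit.script(AddOne())
data = script_module_to_bytes(model)

# Load the module back from the same bytes to confirm nothing was lost.
restored = torch.jit.load(io.BytesIO(data))
print(restored(torch.tensor([1.0, 2.0])))
```

On the C++ side, a bound function could accept these bytes (pybind11 converts Python `bytes` to `std::string`), wrap them in a `std::istringstream`, and pass the stream to the `torch::jit::load` overload that takes a `std::istream&`. This still serializes the module, just without an intermediate file; check the exact overload against your LibTorch version.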