Passing ScriptModule to C++ in 1.6


In 1.5, I could create a pybind11 extension with a function that takes a jitted module, like:

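// c++
#include <torch/extension.h>
#include <torch/script.h>
#include <iostream>

void go(torch::jit::Module& m) {
    auto out = m.forward({/* inputs */}).toTensor().sum();
    std::cout << out << std::endl;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("go", &go); }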

# python
import torch
from torch import nn
from torch.utils import cpp_extension

ext = cpp_extension.load(name='ext_cpp', sources=[''])

class MyMod(nn.Module):
    def forward(self, x):
        return x*x

m = torch.jit.script(MyMod())

And everything Just Worked. Now in 1.6 I can’t figure out what argument type it expects. I’ve tried Module, Object, IValue, and everything else I could think of or find in the code. Interestingly, _run_emit_module_hook is called with the RecursiveScriptModule's inner cpp_module and is declared as taking a torch::jit::Module, so I expected that to work, but no dice.
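To make the "inner cpp_module" distinction concrete: the object torch.jit.script returns is a Python wrapper, and the C++ module lives in its _c attribute. Note that _c is private API, so this is only a quick inspection sketch and may not hold across releases:

```python
import torch
from torch import nn

class MyMod(nn.Module):
    def forward(self, x):
        return x * x

m = torch.jit.script(MyMod())

# The Python-side wrapper class returned by torch.jit.script.
print(type(m).__name__)

# Private attribute holding the underlying C++ torch::jit::Module,
# which is exposed to Python as torch._C.ScriptModule.
inner = m._c
print(type(inner).__name__)
print(isinstance(inner, torch._C.ScriptModule))
```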

Any help? :innocent:

…or, if this is an XY problem, what’s the recommended way to pass a module to an extension? Should I grab the custom class from the registry, or just call save_to_buffer and load the result from an istringstream?