Dear all,
I’m no C++ expert, so forgive me if this is a simple question, but I cannot seem to figure it out.
I have a module from which I want to export several methods to TorchScript in Python and then load them in C++. For this I read the PyTorch documentation for `torch.jit.trace_module` and `torch.jit.trace`.
When exporting things from Python using `torch.jit.trace`, I follow the same approach, which always works perfectly. This time, however, I really cannot use methods named `forward`, so I have to go through `torch.jit.trace_module`.
I roughly use this Python code:
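```python
import torch

my_module = MyModule()
# Map each method name to the example inputs used to trace it
# (this runs inside a method of my exporting class, hence `self`).
trace_inputs = {
    'method1': self.create_dummy_input_method1(),
    'method2': self.create_dummy_input_method2(),
}
traced_torch_script = torch.jit.trace_module(my_module, trace_inputs)
traced_torch_script.save(export_path)
```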
Exporting works.
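For reference, here is a minimal, self-contained sketch of that export step. It rests on assumptions: `MyModule` below is a hypothetical toy stand-in (one `Linear` layer, two methods `method1`/`method2`, no `forward`), the example inputs are random tensors instead of the real dummy-input helpers, and an in-memory buffer replaces `export_path`. It traces both methods with `torch.jit.trace_module`, saves and reloads the result, and calls the methods by name from Python.

```python
import io

import torch


class MyModule(torch.nn.Module):
    """Hypothetical toy stand-in: two named methods and no forward."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def method1(self, x):
        return self.linear(x)

    def method2(self, x):
        return torch.relu(self.linear(x))


my_module = MyModule().eval()
x = torch.randn(3, 4)  # assumed example input shape for both methods

# Keys are the method names to trace; values are their example inputs.
trace_inputs = {
    "method1": (x,),
    "method2": (x,),
}
traced = torch.jit.trace_module(my_module, trace_inputs)

# Save to an in-memory buffer instead of a file, then load it back.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# The traced methods keep their names after the save/load round trip.
out1 = loaded.method1(x)
out2 = loaded.method2(x)
print(out1.shape, out2.shape)
```

This only confirms the Python side of the round trip; it does not show how to reach the methods from LibTorch.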
Then, I try to load this from C++ using LibTorch:
```cpp
#include <torch/script.h>  // LibTorch's TorchScript header

torch::jit::script::Module module = torch::jit::load(import_path);
```
This also works (at least, I get no error). Then, I try to call a method on the module:
```cpp
std::vector<torch::jit::IValue> inputs;
// inputs = ...
torch::Tensor output_tensor = module.forward(inputs).toTensor();
```
This fails, which seems logical since I did not export a `forward` method. (It does work if I export the module with `torch.jit.trace` instead of `torch.jit.trace_module`.)
However, the `torch.jit.trace_module` docs state:
> **TORCH.JIT.TRACE_MODULE**
>
> …
>
> **Returns**
>
> A `ScriptModule` object with a single `forward` method containing the traced code. When `func` is a `torch.nn.Module`, the returned `ScriptModule` will have the same set of sub-modules and parameters as `func`.
Is this a copy-paste error carried over from the `torch.jit.trace` docs? It does not look like one, so presumably it is not an error, in which case I’m afraid I don’t understand it.
I also tried changing `forward` to `method1`:
```cpp
torch::Tensor output_tensor = module.method1(inputs).toTensor();
```
Unfortunately, I then get a compile error saying that `module` has no member `method1` (which makes sense to me from a C++ point of view).
How can I call `MyModule`’s methods `method1` and `method2` in C++ using LibTorch?
Thank you!