Loading a TorchScript module exported with torch.jit.trace_module() in LibTorch

Dear all,

I’m no C++ expert, so forgive me if this is a simple question, but I cannot seem to figure it out.

I have a module for which I want to export several methods to TorchScript from Python and then load them in C++. For this, I read the PyTorch documentation for torch.jit.trace_module and torch.jit.trace.

When exporting from Python with torch.jit.trace, I always follow the same approach, and it has always worked perfectly. This time, however, I cannot use a method named forward, so I have to go through torch.jit.trace_module.

I roughly use this Python code:

import torch

my_module = MyModule()

# Map each method name to the example inputs used to trace that method.
trace_inputs = {
    'method1': self.create_dummy_input_method1(),
    'method2': self.create_dummy_input_method2(),
}

traced_torch_script = torch.jit.trace_module(my_module, trace_inputs)
traced_torch_script.save(export_path)
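For reference, a self-contained version of this export step might look like the sketch below. The body of MyModule, the two create_dummy_input_* helpers (written here as plain functions instead of methods on self), the tensor shapes and the temporary export_path are all hypothetical placeholders; only torch.jit.trace_module, the dict-of-method-names input format and save come from the PyTorch API. Each value in trace_inputs is a tuple of example positional arguments for the method named by its key.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MyModule(nn.Module):
    """Hypothetical stand-in for the real MyModule: no forward, two named methods."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def method1(self, x):
        # Hypothetical computation for method1 (one tensor argument).
        return self.linear(x)

    def method2(self, x, y):
        # Hypothetical computation for method2 (two tensor arguments).
        return self.linear(x) + y


def create_dummy_input_method1():
    # Example inputs for method1: a tuple with one entry per positional argument.
    return (torch.randn(3, 4),)


def create_dummy_input_method2():
    # Example inputs for method2: a tuple with one entry per positional argument.
    return (torch.randn(3, 4), torch.randn(3, 2))


my_module = MyModule()

# Keys are the names of the methods to trace; values are their example inputs.
trace_inputs = {
    "method1": create_dummy_input_method1(),
    "method2": create_dummy_input_method2(),
}

traced_torch_script = torch.jit.trace_module(my_module, trace_inputs)

# Hypothetical export location in a fresh temporary directory.
export_path = os.path.join(tempfile.mkdtemp(), "my_module.pt")
traced_torch_script.save(export_path)

# From Python, the traced methods are callable under their original names.
x = torch.randn(3, 4)
y = torch.randn(3, 2)
out1 = traced_torch_script.method1(x)
out2 = traced_torch_script.method2(x, y)
```

In this sketch the traced module shares its parameters with my_module, so out1 and out2 should match calling the eager methods on the same x and y.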

Exporting works 🙂.

Then, I try to load this from C++ using LibTorch:

torch::jit::script::Module myModule = torch::jit::load(import_path);

This also works (at least, I get no error). Then, I try to call a method on the module:

std::vector<torch::jit::IValue> inputs;
inputs = ...
torch::Tensor output_tensor = myModule.forward(inputs).toTensor();

This fails, which seems logical because I did not export a forward method (it does work if I export the module with torch.jit.trace instead of torch.jit.trace_module).
However, the torch.jit.trace_module docs state:

TORCH.JIT.TRACE_MODULE

Returns

A ScriptModule object with a single forward method containing the traced code. When func is a torch.nn.Module, the returned ScriptModule will have the same set of sub-modules and parameters as func.

Is this a copy-paste error carried over from the torch.jit.trace docs? The wording does not look like a straight copy-paste, so perhaps it is not an error. In that case, I’m afraid I don’t understand it.
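One way to check what a trace_module export actually contains, independent of the docs’ wording, is to reload the saved file with torch.jit.load from Python and call the methods by name. The minimal sketch below assumes a hypothetical TwoMethods module (not the real MyModule) with two methods and no forward; if both methods are present after reloading, the "single forward method" sentence does not seem to describe the trace_module case.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TwoMethods(nn.Module):
    # Hypothetical module with two named methods and no forward.
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(2.0))

    def method1(self, x):
        return x * self.scale

    def method2(self, x):
        return x + self.scale


m = TwoMethods()
example = torch.ones(3)
traced = torch.jit.trace_module(m, {"method1": (example,), "method2": (example,)})

# Save to a hypothetical temporary path, then read the archive back.
path = os.path.join(tempfile.mkdtemp(), "two_methods.pt")
traced.save(path)
loaded = torch.jit.load(path)

# Check whether both traced methods survive serialization under their names.
has_method1 = hasattr(loaded, "method1")
has_method2 = hasattr(loaded, "method2")
r1 = loaded.method1(torch.tensor([1.0, 2.0, 3.0]))
r2 = loaded.method2(torch.tensor([1.0, 2.0, 3.0]))
```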

I also tried changing forward to method1:

torch::Tensor output_tensor = myModule.method1(inputs).toTensor();

Unfortunately, that gives me a compile error saying that the module has no member named method1 (which makes sense to me from a C++ point of view).

How can I call MyModule’s methods method1 and method2 in C++ using LibTorch?

Thank you!