Using TorchTRT-FX backend on C++

I read about the torch-trt package and its ability to convert a Torch model to TorchScript and then run it in C++. There are examples, and I understand them.

Anyway, we have the other option: tracing the model using torch_tensorrt.fx and then saving it to disk. The .pt file contains all the sub-graphs, which are TRTModules or GraphModules.
Can I load this .pt (or some other file type) with the torch_tensorrt C++ API and use it there?
TorchScript doesn't provide an API for splitting the model into sub-graphs and so on, so I want to use the flexibility of FX in my runtime environment (in C++).

The package should work in C++, but FX (as I understand it) is a Python-only tool, and I can't find an example of loading a .pt file created by torch_tensorrt.fx.

CC @narendasan as the expert on TorchTRT

@daniellevi With the new version of Torch-TensorRT (1.3.0) you can use the FX frontend to compile your module to target the TorchScript/TensorRT runtime by passing use_experimental_fx_rt. It produces a representation that can be passed to torch.jit.trace, and the traced module can be saved, loaded in C++, and run in the same way as a module compiled with the TorchScript frontend.

For example:

import torch
import torch_tensorrt

model_fx = model_fx.cuda()
inputs_fx = [i.cuda() for i in inputs_fx]
trt_fx_module_f16 = torch_tensorrt.compile(
    model_fx,
    ir="fx",
    inputs=inputs_fx,
    enabled_precisions={torch.float16},
    use_experimental_fx_rt=True,
    explicit_batch_dimension=True
)

# Save the model using torch.save
torch.save(trt_fx_module_f16, "trt.pt")
reload_trt_mod = torch.load("trt.pt")

# Trace and save the FX module in TorchScript
scripted_fx_module = torch.jit.trace(trt_fx_module_f16, example_inputs=inputs_fx)
scripted_fx_module.save("/tmp/scripted_fx_module.ts")
scripted_fx_module = torch.jit.load("/tmp/scripted_fx_module.ts")  # This can also be loaded in C++

I will try it.
Thank you!

Hi
I tried what you wrote, but the loading fails.
The way I'm trying to load the model is with this line:
torch::jit::Module m_module = torch::jit::load(m_modelFile);

Is this correct, or is there another way to load it?
Do you perhaps know of an example script that saves a model after torch.fx processing and loads it into the C++ API?

Thanks

We don't have a specific example for this, but that is correct. Once the FX-traced TorchScript is saved to disk following the workflow above, the process is the same as for any other TorchScript, so you can follow this tutorial: TensorRT/Resnet50-CPP.ipynb at main · pytorch/TensorRT · GitHub
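
For reference, a minimal C++ sketch along the lines of that tutorial might look like the following. This is only a sketch: the module path and the input shape/dtype are placeholders and should match what you traced with, and since the saved module embeds TensorRT engines, the Torch-TensorRT runtime library (libtorchtrt_runtime.so) generally needs to be linked or loaded alongside libtorch.

#include <torch/script.h>
#include <iostream>
#include <vector>

int main() {
    torch::jit::Module module;
    try {
        // Placeholder path: the .ts file saved from Python above
        module = torch::jit::load("/tmp/scripted_fx_module.ts");
    } catch (const c10::Error& e) {
        std::cerr << "Error loading the module: " << e.what() << std::endl;
        return -1;
    }
    module.to(torch::kCUDA);
    module.eval();

    // Placeholder input: match the shape and dtype of the tracing inputs
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::randn({1, 3, 224, 224},
                                  torch::TensorOptions().device(torch::kCUDA)));

    torch::Tensor output = module.forward(inputs).toTensor();
    std::cout << output.sizes() << std::endl;
    return 0;
}

If loading fails with an error about an unknown operator or class, that usually points to the Torch-TensorRT runtime not being present in the process, so double-check that it is on the link line (or loaded via dlopen) before calling torch::jit::load.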