Using the Torch-TensorRT FX backend from C++

I have read about the Torch-TensorRT package and its ability to convert a PyTorch model to TorchScript and then run it from C++. There are examples for this, and I understand them.

However, there is another option: trace the model using torch_tensorrt.fx and then save it to disk. The resulting .pt file contains all of the sub-graphs, which are either TRTModules or GraphModules.
Can I load this .pt file (or any other file type) with the Torch-TensorRT C++ API and use it there?
TorchScript doesn't provide an API for splitting the model into sub-graphs and so on, so I want to use the flexibility of FX in my runtime environment (C++).

The package is supposed to work from C++, but as I understand it, FX is a Python-only tool, and I can't find an example of loading a .pt file created by torch_tensorrt.fx.
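For context, the FX-level sub-graph splitting the question refers to can be sketched with plain torch.fx, independent of TensorRT. This is a minimal sketch assuming a recent PyTorch that ships torch.fx.passes.split_module; TinyNet and the partition rule partition_of are hypothetical stand-ins, not part of Torch-TensorRT.

```python
import torch
import torch.fx
from torch.fx.passes.split_module import split_module


class TinyNet(torch.nn.Module):
    """Hypothetical toy model standing in for the real network."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 16)
        self.fc2 = torch.nn.Linear(16, 4)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


model = TinyNet().eval()

# Capture the model as an FX graph (placeholder -> fc1 -> relu -> fc2 -> output)
traced = torch.fx.symbolic_trace(model)


def partition_of(node):
    # Hypothetical rule: fc1 and the ReLU go to partition 0, fc2 to partition 1
    return 1 if node.target == "fc2" else 0


# Each partition becomes a child GraphModule named submod_<partition>
split = split_module(traced, model, partition_of)
print(split)
```

Each submod_N is an ordinary GraphModule, which is the granularity at which a backend such as torch_tensorrt.fx can decide what to lower to TensorRT.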

CC @narendasan as the expert on TorchTRT

@daniellevi With the new version of Torch-TensorRT (1.3.0), you can use the FX frontend to compile your module for the TorchScript/TensorRT runtime by passing use_experimental_fx_rt=True. This produces a representation that can be passed to torch.jit.trace; the traced module can then be saved, loaded in C++, and run in the same way as a module compiled with the TorchScript frontend.

For example:

import torch
import torch_tensorrt

# model_fx and inputs_fx are your own model and list of example input tensors
model_fx = model_fx.cuda()
inputs_fx = [i.cuda() for i in inputs_fx]

# Compile with the FX frontend, targeting the TorchScript/TensorRT runtime
trt_fx_module_f16 = torch_tensorrt.compile(
    model_fx,
    ir="fx",
    inputs=inputs_fx,
    enabled_precisions={torch.float16},
    use_experimental_fx_rt=True,
    explicit_batch_dimension=True
)

# Save and reload the compiled module using torch.save / torch.load
torch.save(trt_fx_module_f16, "trt.pt")
reload_trt_mod = torch.load("trt.pt")

# Trace the FX module into TorchScript and save it
scripted_fx_module = torch.jit.trace(trt_fx_module_f16, example_inputs=inputs_fx)
scripted_fx_module.save("/tmp/scripted_fx_module.ts")
scripted_fx_module = torch.jit.load("/tmp/scripted_fx_module.ts")  # This file can also be loaded in C++ with torch::jit::load
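The TorchScript part of this workflow (trace, save, reload, compare outputs) can be checked on CPU without TensorRT. This is a minimal sketch assuming only PyTorch; TinyNet is a hypothetical toy model standing in for the compiled module, and the file is written to a temporary directory instead of /tmp.

```python
import os
import tempfile

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical toy model standing in for the compiled module."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 16)
        self.fc2 = torch.nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet().eval()
example = torch.randn(2, 8)

with torch.no_grad():
    # Record the operations run on the example input as a TorchScript graph
    scripted = torch.jit.trace(model, example_inputs=example)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_model.ts")
    scripted.save(path)              # serialize to a TorchScript archive
    reloaded = torch.jit.load(path)  # the same file C++ would open with torch::jit::load

with torch.no_grad():
    out_ref = model(example)
    out_reloaded = reloaded(example)

print(torch.allclose(out_ref, out_reloaded))
```

On the C++ side the same archive would be opened with torch::jit::load. For a module that actually contains TensorRT engines, the C++ program presumably also needs to link the Torch-TensorRT runtime library so the engine-execution ops are registered; the Torch-TensorRT C++ documentation is the place to confirm the exact library to link.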

I will try it.
Thank you!