Unable to compile RAFT model in torch_tensorrt

I have tried several pipelines without success to convert the RAFT model with Torch-TensorRT. I tried three pipelines in two distinct Python environments, but everything fails.

OS: Ubuntu 20.04
Python: 3.8

Environment 1:

  • torch 2.0.1
  • torch-tensorrt 1.4.0
  • torchvision 0.15.2
  • tensorrt 8.6.1

Environment 2:

  • torch 1.12.1
  • torch-tensorrt 1.2.0
  • torchvision 0.13.1
  • tensorrt 8.0.3.4

Pipeline 1: torch_tensorrt.compile

import torch
import torch_tensorrt
from torchvision.models.optical_flow import Raft_Small_Weights
from torchvision.models.optical_flow import raft_small

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device", device)
with torch_tensorrt.logging.debug():
    model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False).to(device)
    model = model.eval()
    # RAFT expects two input images
    inputs = [
        torch.randn((1, 3, 256, 256), dtype=torch.float).to(device),  # example tensor; let torch_tensorrt infer settings
        torch.randn((1, 3, 256, 256), dtype=torch.float).to(device),  # example tensor; let torch_tensorrt infer settings
    ]
    enabled_precisions = {torch.half}  # run with fp16

    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )
    print('conversion done!')

Output pipeline 1:

device cuda:0
INFO: [Torch-TensorRT] - ir was set to default, using TorchScript as ir
INFO: [Torch-TensorRT] - Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript
DEBUG: [Torch-TensorRT] - TensorRT Compile Spec: {
    "Inputs": [
Input(shape=(1,3,256,256,), dtype=Float, format=Contiguous/Linear/NCHW, tensor_domain=[0, 2))Input(shape=(1,3,256,256,), dtype=Float, format=Contiguous/Linear/NCHW, tensor_domain=[0, 2))    ]
    "Enabled Precision": [Half, ]
    "TF32 Disabled": 0
    "Sparsity": 0
    "Refit": 0
    "Debug": 0
    "Device":  {
        "device_type": GPU
        "allow_gpu_fallback": False
        "gpu_id": 0
        "dla_core": -1
    }

    "Engine Capability": Default
    "Num Avg Timing Iters": 1
    "Workspace Size": 0
    "DLA SRAM Size": 1048576
    "DLA Local DRAM Size": 1073741824
    "DLA Global DRAM Size": 536870912
    "Truncate long and double": 0
    "Allow Shape tensors": 0
    "Torch Fallback":  {
        "enabled": True
        "min_block_size": 3
        "forced_fallback_operators": [
        ]
        "forced_fallback_modules": [
        ]
    }
}
DEBUG: [Torch-TensorRT] - init_compile_spec with input vector
DEBUG: [Torch-TensorRT] - Settings requested for Lowering:
    torch_executed_modules: [
    ]
Traceback (most recent call last):
  File "/home/username/src/dev/deepWork/raft1.py", line 17, in <module>
    trt_ts_module = torch_tensorrt.compile(
  File "/home/username/bin/anaconda3/envs/dmm/lib/python3.9/site-packages/torch_tensorrt/_compile.py", line 133, in compile
    return torch_tensorrt.ts.compile(
  File "/home/username/bin/anaconda3/envs/dmm/lib/python3.9/site-packages/torch_tensorrt/ts/_compiler.py", line 139, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: 
temporary: the only valid use of a module is looking up an attribute but found  = prim::SetAttr[name="corr_pyramid"](%corr_block.1, %514)
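
My reading of this error is that torchvision's RAFT stores the correlation pyramid as a module attribute (CorrBlock.build_pyramid assigns self.corr_pyramid during the forward pass), and torch.jit.script cannot express that mutation for Torch-TensorRT. Pipeline 2 below avoids scripting by tracing; for reference, here is a minimal sketch of the tracing variant with an added torch.jit.freeze step, which as I understand it folds module attributes into constants (the freeze step is my assumption, not a verified fix):

import torch
import torch_tensorrt

# Sketch only: trace instead of script, then freeze so attributes such as
# corr_pyramid are inlined as constants before handing the module to TRT.
traced = torch.jit.trace(model, (inputs[0], inputs[1]))
frozen = torch.jit.freeze(traced.eval())
trt_ts_module = torch_tensorrt.ts.compile(
    frozen, inputs=inputs, enabled_precisions={torch.half}
)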

Pipeline 2: torch_tensorrt.ts.compile or torch_tensorrt.ts.TensorRTCompileSpec

import torch
import torch_tensorrt
from torchvision.models.optical_flow import Raft_Small_Weights
from torchvision.models.optical_flow import raft_small

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device", device)
with torch_tensorrt.logging.debug():
    model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False).to(device)
    model = model.eval()
    print('model loaded ')
    # print(model)
    inputs = [
        torch.randn((1, 3, 256, 256), dtype=torch.float).to(device),  # example tensor; let torch_tensorrt infer settings
        torch.randn((1, 3, 256, 256), dtype=torch.float).to(device),  # example tensor; let torch_tensorrt infer settings
    ]
    model_scripted = torch.jit.trace(model, inputs)
    # model_scripted = torch.jit.script(model)
    print('torchscript conversion done !')
    model_scripted.to(device)
    print('Start torch-tensoRT compilation... ')
    # torch_tensorrt.ts.check_method_op_support(model_scripted)
    model_trt = torch_tensorrt.ts.compile(
        model_scripted,
        inputs=inputs,
        input_signature=inputs,
        device=device,
        workspace_size=0,
        # List of aten operators that must be run in PyTorch.
        torch_executed_ops=['corr_pyramid', 'corr_block.1',
                            'corr_block', 'corr_volume',
                            'normalized_grid', 'GridSample',
                            'sub', 'aten::sub'],
        # List of modules that must be run in PyTorch.
        torch_executed_modules=['corr_pyramid', 'corr_block.1',
                                'corr_block', 'corr_volume',
                                'normalized_grid', 'GridSample',
                                'corr_block'],
        enabled_precisions={torch.half},
        debug=True,
        refit=False,
        num_avg_timing_iters=1,
        calibrator=None,
        truncate_long_and_double=False,
        require_full_compilation=False,
        min_block_size=3,
        allow_shape_tensors=False,
    )
    print('conversion done!')

Output pipeline 2:

device cuda:0
model loaded 
/home/username/bin/anaconda3/envs/dmm/lib/python3.9/site-packages/torchvision/models/optical_flow/raft.py:418: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return corr / torch.sqrt(torch.tensor(num_channels))
/home/username/bin/anaconda3/envs/dmm/lib/python3.9/site-packages/torch/jit/_trace.py:1056: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
  module._c._create_method_from_trace(
torchscript conversion done !
Start torch-tensoRT compilation... 
DEBUG: [Torch-TensorRT] - TensorRT Compile Spec: {
    "Inputs": [
Input(shape=(1,3,256,256,), dtype=Float, format=Contiguous/Linear/NCHW, tensor_domain=[0, 2))Input(shape=(1,3,256,256,), dtype=Float, format=Contiguous/Linear/NCHW, tensor_domain=[0, 2))    ]
    "Enabled Precision": [Half, ]
    "TF32 Disabled": 0
    "Sparsity": 0
    "Refit": 0
    "Debug": 1
    "Device":  {
        "device_type": GPU
        "allow_gpu_fallback": False
        "gpu_id": 0
        "dla_core": -1
    }

    "Engine Capability": Default
    "Num Avg Timing Iters": 1
    "Workspace Size": 0
    "DLA SRAM Size": 1048576
    "DLA Local DRAM Size": 1073741824
    "DLA Global DRAM Size": 536870912
    "Truncate long and double": 0
    "Allow Shape tensors": 0
    "Torch Fallback":  {
        "enabled": True
        "min_block_size": 3
        "forced_fallback_operators": [
            corr_pyramid,
            corr_block.1,
            corr_block,
            corr_volume,
            normalized_grid,
            GridSample,
            sub,
            aten::sub,
        ]
        "forced_fallback_modules": [
            corr_pyramid,
            corr_block.1,
            corr_block,
            corr_volume,
            normalized_grid,
            GridSample,
            corr_block,
        ]
    }
}
DEBUG: [Torch-TensorRT] - init_compile_spec with input vector
DEBUG: [Torch-TensorRT] - Settings requested for Lowering:
    torch_executed_modules: [
      corr_pyramid
      corr_block.1
      corr_block
      corr_volume
      normalized_grid
      GridSample
      corr_block
    ]
DEBUG: [Torch-TensorRT] - RemoveNOPs - Note: Removing operators that have no meaning in TRT
Traceback (most recent call last):
  File "/home/username/src/dev/deepWork/raft2.py", line 22, in <module>
    model_trt = torch_tensorrt.ts.compile(model_scripted, 
  File "/home/username/bin/anaconda3/envs/dmm/lib/python3.9/site-packages/torch_tensorrt/ts/_compiler.py", line 139, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: 
Schema not found for node. File a bug report.
Node: %18222 : int = aten::sub(%15973, %18221, %24)

Input types:int, int, int
candidates were:
  aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
  aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)
  aten::sub.int(int a, int b) -> int
  aten::sub.complex(complex a, complex b) -> complex
  aten::sub.float(float a, float b) -> float
  aten::sub.int_complex(int a, complex b) -> complex
  aten::sub.complex_int(complex a, int b) -> complex
  aten::sub.float_complex(float a, complex b) -> complex
  aten::sub.complex_float(complex a, float b) -> complex
  aten::sub.int_float(int a, float b) -> float
  aten::sub.float_int(float a, int b) -> float
  aten::sub(Scalar a, Scalar b) -> Scalar
within the graph:
graph(%image1 : Tensor,
      %image2 : Tensor):
  %self.feature_encoder.convnormrelu.0.bias.1 : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  %self.feature_encoder.convnormrelu.0.weight.1 : Float(32, 3, 7, 7, strides=[147, 49, 7, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
  ...
  ... # a lot of torchscript debug lines ...
  ...
    %1797 : Tensor = aten::mul(%18167, %12) # /home/username/bin/anaconda3/envs/dmm/lib/python3.9/site-packages/torchvision/models/optical_flow/_utils.py:40:0
  %1798 : Tensor[] = prim::ListConstruct(%477, %597, %717, %837, %957, %1077, %1197, %1317, %1437, %1557, %1677, %1797)
  return (%1798)

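Looking at this output again, I suspect the fallback lists are misused: corr_volume, normalized_grid, and GridSample are value/node names from the traced graph, whereas (as I understand the docs) torch_executed_ops expects aten schema names and torch_executed_modules expects module type paths. Forcing 'sub'/'aten::sub' to fall back also seems to be what produces the int,int,int aten::sub node that has no schema. A hedged sketch of what I believe the intended usage looks like (the op name and the module path are my assumptions):

# Sketch only: aten::grid_sampler is the op behind F.grid_sample; the
# CorrBlock path is assumed from torchvision's source layout.
model_trt = torch_tensorrt.ts.compile(
    model_scripted,
    inputs=inputs,
    torch_executed_ops=["aten::grid_sampler"],
    torch_executed_modules=["torchvision.models.optical_flow.raft.CorrBlock"],
    enabled_precisions={torch.half},
    min_block_size=3,
)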

Pipeline 3: .pth → ONNX → trtexec

  • The ONNX export works as expected,
  • but the trtexec conversion fails:
# [07/04/2023-18:06:19] [V] [TRT] Parsing node: GridSample_329 [GridSample]
# [07/04/2023-18:06:19] [V] [TRT] Searching for input: corr_volume
# [07/04/2023-18:06:19] [V] [TRT] Searching for input: normalized_grid
# [07/04/2023-18:06:19] [V] [TRT] GridSample_329 [GridSample] inputs: [corr_volume -> (1024, 1, 32, 32)[FLOAT]], [normalized_grid -> (1024, 7, 7, 2)[FLOAT]], 
# [07/04/2023-18:06:19] [I] [TRT] No importer registered for op: GridSample. Attempting to import as plugin.
# [07/04/2023-18:06:19] [I] [TRT] Searching for plugin: GridSample, plugin_version: 1, plugin_namespace: 
# [07/04/2023-18:06:19] [E] [TRT] 3: getPluginCreator could not find plugin: GridSample version: 1
# [07/04/2023-18:06:19] [E] [TRT] ModelImporter.cpp:720: While parsing node number 329 [GridSample -> "onnx::Reshape_511"]:
# [07/04/2023-18:06:19] [E] [TRT] ModelImporter.cpp:721: --- Begin node ---
# [07/04/2023-18:06:19] [E] [TRT] ModelImporter.cpp:722: input: "corr_volume"
# input: "normalized_grid"
# output: "onnx::Reshape_511"
# name: "GridSample_329"
# op_type: "GridSample"
# attribute {
#   name: "align_corners"
#   i: 1
#   type: INT
# }
# attribute {
#   name: "mode"
#   s: "bilinear"
#   type: STRING
# }
# attribute {
#   name: "padding_mode"
#   s: "zeros"
#   type: STRING
# }


# [07/04/2023-18:06:19] [E] [TRT] ModelImporter.cpp:723: --- End node ---
# [07/04/2023-18:06:19] [E] [TRT] ModelImporter.cpp:726: ERROR: builtin_op_importers.cpp:4643 In function importFallbackPluginImporter:
# [8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
# [07/04/2023-18:06:19] [E] Failed to parse onnx file
# [07/04/2023-18:06:19] [I] Finish parsing network model
# [07/04/2023-18:06:19] [E] Parsing model failed
# [07/04/2023-18:06:19] [E] Engine creation failed
# [07/04/2023-18:06:19] [E] Engine set up failed
# &&&& FAILED TensorRT.trtexec [TensorRT v8003] # /usr/src/tensorrt/bin/trtexec --onnx=/home/username/.cache/torch/hub/checkpoints/raft_small_C_T_V2-01064c6d.onnx --saveEngine=/home/username/.cache/torch/hub/checkpoints/raft_small_C_T_V2-01064c6d.trt --explicitBatch --fp16 --workspace=1024 --verbose
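
One detail worth noting: the banner reports TensorRT v8003, i.e. this trtexec came from the 8.0.3 install of environment 2, and as far as I know the ONNX parser only gained a native GridSample importer around TensorRT 8.5 (matching the GridSample op introduced in ONNX opset 16). So re-exporting with opset 16 and parsing with the TensorRT 8.6.1 environment might work; a minimal export sketch (the file name and input names are placeholders):

import torch
from torchvision.models.optical_flow import Raft_Small_Weights, raft_small

model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=False).eval()
example = (torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
# opset 16 is the first opset with GridSample, so aten::grid_sampler can
# export natively instead of requiring a TensorRT plugin
torch.onnx.export(
    model, example, "raft_small_opset16.onnx",
    opset_version=16,
    input_names=["image1", "image2"],
)

and then run trtexec from the TensorRT 8.6.1 installation rather than /usr/src/tensorrt/bin/trtexec (8.0.3).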

Could you help and provide a documented explanation of how to convert this RAFT network? I have successfully converted VGG and other simple backbones, but this network is hard to debug.

PS: I also tried adding ir="fx" to use the FX frontend, as suggested in this post: Unable to compile model in torch_tensorrt, but without success.
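For completeness, the FX attempt was just the documented ir switch on the same call, i.e. something like:

trt_fx_module = torch_tensorrt.compile(
    model,
    ir="fx",  # route through the FX frontend instead of TorchScript
    inputs=inputs,
    enabled_precisions={torch.half},
)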

CC @narendasan for TorchTRT

Does anyone have any suggestions to resolve this issue?