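When I compile the model with the following code, which declares a dynamic batch size (1–16) through `torch_tensorrt.Input`, I get an error (traceback excerpt below):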
input = torch_tensorrt.Input(
min_shape=(1, 3, config.model_image_size, config.model_image_size),
opt_shape=(16, 3, config.model_image_size, config.model_image_size),
max_shape=(16, 3, config.model_image_size, config.model_image_size),
dtype=torch.half, name="x")
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[input], enabled_precisions = {torch.half, torch.float}, output_format="torchscript")
torch.jit.save(trt_gm, "trt_model.ts")
File "/mnt/c/Coding/Testing/PyTorch/MultiClassImageClassification/src/imclaslib/models/multilabel_classifier.py", line 22, in forward
image_features = self.base_model(x) # [batch_size, feature_dim]
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/timm/models/maxxvit.py", line 1264, in forward
x = self.forward_features(x)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/timm/models/maxxvit.py", line 1255, in forward_features
x = self.stem(x)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/timm/models/maxxvit.py", line 1098, in forward
x = self.norm1(x)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/timm/layers/norm_act.py", line 118, in forward
x = F.batch_norm(
However, I don’t get this error if I pass a real example batch of images (a tensor with a fixed, static shape) instead of a dynamic-shape `torch_tensorrt.Input`:
images = images.half()
model = model.half()
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[images], enabled_precisions = {torch.half, torch.float}, output_format="torchscript")
After enabling debug logging in torch_tensorrt, I got the root error message below, but I’m still not sure how to work around it:
Traceback (most recent call last):
File "/mnt/c/Coding/Testing/PyTorch/MultiClassImageClassification/src/compressmodel.py", line 48, in <module>
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=[input], enabled_precisions = {torch.half, torch.float}, output_format="torchscript")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 228, in compile
trt_graph_module = dynamo_compile(
^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 236, in compile
trt_gm = compile_module(gm, inputs, settings)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/_compiler.py", line 346, in compile_module
trt_module = convert_module(
^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 56, in convert_module
interpreter_result = interpreter.run()
^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 152, in run
super().run()
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch/fx/interpreter.py", line 138, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 276, in run_node
trt_node: torch.fx.Node = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch/fx/interpreter.py", line 195, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 362, in call_function
return converter(self.ctx, target, args, kwargs, self._cur_node_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 103, in aten_ops_batch_norm_legit_no_training
return impl.normalization.batch_norm(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/multilabelimage_model_env/lib/python3.11/site-packages/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py", line 60, in batch_norm
if not ctx.net.has_implicit_batch_dimension and len(input.shape) < 4:
^^^^^^^^^^^^^^^^
ValueError: __len__() should return >= 0