Hi,
I am having issues exporting a pytorch model to onnx via torch.onnx.dynamo_export.
I created a small code example for reproduction of the issue and have the following questions:
- Why leads the expression in the forward pass to the shown error when exporting?
- What would be the optimal solution to avoid this error?
This is my first time exporting a model with dynamo_export and i would be happy for any explanations.
Thanks in advance!
Issue description
The export fails because of this expression in the forward pass of the model:
tensor_1[:, indices] = tensor_2
Code example
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, indices:torch.Tensor, tensor_1:torch.Tensor, tensor_2:torch.Tensor):
tensor_1[:, indices] = tensor_2
return tensor_1
# create model
model = MyModel()
model.cuda()
model.eval()
# create inputs
tensor_1 = torch.rand((32, 100), dtype=torch.float32, device="cuda:0")
tensor_2 = torch.rand((32, 50), dtype=torch.float32, device="cuda:0")
indices = torch.randint(low=0, high=100, size=(50,), dtype=torch.long, device="cuda:0")
# export model
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_model = torch.onnx.dynamo_export(model, indices, tensor_1, tensor_2, export_options=export_options)
Error Message
Traceback (most recent call last):
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/exporter.py", line 1428, in dynamo_export
return Exporter(
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/exporter.py", line 1186, in export
onnxscript_graph = fx_interpreter.run(
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py", line 151, in wrapper
ctx.log_and_raise_if_error(diag)
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/context.py", line 366, in log_and_raise_if_error
raise diagnostic.source_exception
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py", line 135, in wrapper
return_values = fn(*args, **kwargs)
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py", line 539, in run
self.run_node(
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py", line 151, in wrapper
ctx.log_and_raise_if_error(diag)
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/context.py", line 366, in log_and_raise_if_error
raise diagnostic.source_exception
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/diagnostics/infra/decorator.py", line 135, in wrapper
return_values = fn(*args, **kwargs)
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py", line 433, in run_node
self.call_function(
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/fx/fx_onnx_interpreter.py", line 665, in call_function
] = symbolic_fn(*onnx_args, **onnx_kwargs)
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/onnxscript/values.py", line 529, in __call__
return evaluator.default().eval_function(self, args, kwargs)
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/onnxscript/function_libs/torch_lib/graph_building.py", line 392, in eval_function
return self._graph.add_function_call(function, inputs, attributes)
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/onnxscript/function_libs/torch_lib/graph_building.py", line 841, in add_function_call
result = self._add_torchscript_op_call(
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/onnxscript/function_libs/torch_lib/graph_building.py", line 705, in _add_torchscript_op_call
graph_inputs.append(self._add_constant_to_graph(input))
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/onnxscript/function_libs/torch_lib/graph_building.py", line 665, in _add_constant_to_graph
raise TypeError(
TypeError: Constant input '[None, l_indices_ defined in (%l_indices_ : Long(*), %l_tensor_1_ : Float(*, *), %l_tensor_2_ : Float(*, *) = prim::Param()
)]' of type '<class 'list'>' is not supported
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "convert_to_onnx_test_script.py", line 29, in <module>
onnx_model = torch.onnx.dynamo_export(model, indices, tensor_1, tensor_2, export_options=export_options)
File "/home/labor/anaconda3/envs/view-of-delft-env/lib/python3.8/site-packages/torch/onnx/_internal/exporter.py", line 1444, in dynamo_export
raise OnnxExporterError(
torch.onnx.OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'.