I have a model that takes a dictionary of torch tensors as input, e.g. with one element:
```python
export_input = ({"head_1": torch.ones(2, 4)},)  # 1-tuple containing the dict
```
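For context, a minimal model with this kind of input could look like the sketch below. `DictInputModel` and its `Linear` layer are hypothetical. Only the `model_input_dict` parameter name comes from the question. The trailing comma makes `export_input` a 1-tuple, which both `*export_input` and `export_input[0]` rely on.

```python
import torch
from torch import nn


class DictInputModel(nn.Module):
    """Hypothetical model whose forward takes a single dict of tensors."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 3)

    def forward(self, model_input_dict):
        # Only the "head_1" entry is used in this sketch.
        return self.linear(model_input_dict["head_1"])


model_to_export = DictInputModel()
# Trailing comma: a 1-tuple whose only element is the dict.
export_input = ({"head_1": torch.ones(2, 4)},)
out = model_to_export(*export_input)
print(out.shape)  # torch.Size([2, 3])
```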
With torch 2.3 this could be exported with:
```python
onnx_program = torch.onnx.dynamo_export(model_to_export, *export_input)
```
Now, with torch 2.7, I am using the `torch.onnx.export` API (with `dynamo=True`). Two things don't work as expected:
- I have to use a workaround for the input, since the API hardcodes non-tensor inputs into the model (I'm not sure what that means). If I pass the input directly, it complains that the positional argument `model_input_dict` is not given, even though that is exactly what `*export_input` should provide.
However, I can solve it like this:
```python
torch.onnx.export(
    model_to_export,
    args=(),
    kwargs={"model_input_dict": export_input[0]},
    f=model_path,
    dynamo=True,
)
```
The second problem: the export works except for one line, where I copy the output, which is dynamic up to this point, into a fixed-size output:
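```python
n_max = 1000  # fixed output size
# class_all_ has a dynamic number of rows
# (the buffer keeps class_all_'s trailing dims; the traceback shows it as [1000, 4])
class_out = torch.zeros(
    (n_max, *class_all_.shape[1:]), dtype=class_all_.dtype, device=class_all_.device
)
class_out.index_copy_(0, indices, class_all_)
```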
This worked fine in torch 2.3. Is this a bug, or do I have to provide some other options?
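Below is a self-contained eager repro of that step. It uses hypothetical stand-in data: a boolean mask gives `class_all_` a data-dependent row count, and `indices` is assumed to be `0..k-1`, since the traceback shows an `arange` of length `u0` as the index tensor. It also shows an alternative formulation that pads with `torch.cat` and gives the same result in eager mode. Whether the `cat` version gets through `torch.onnx.export(..., dynamo=True)` in 2.7 is untested. It may run into the same symbolic-shape issue, because `n_max - k` is still symbolic.

```python
import torch

# Hypothetical stand-ins: class_all_ gets a data-dependent number of rows via a mask,
# and indices is assumed to be 0..k-1.
scores = torch.tensor([0.9, 0.1, 0.8, 0.3, 0.7])
features = torch.arange(20, dtype=torch.float32).reshape(5, 4)
class_all_ = features[scores > 0.5]  # 3 rows here; the count depends on the data
indices = torch.arange(class_all_.shape[0])

n_max = 1000  # fixed output size

# Formulation from the question: write the rows into a zero buffer.
class_out = torch.zeros(
    (n_max, *class_all_.shape[1:]), dtype=class_all_.dtype, device=class_all_.device
)
class_out.index_copy_(0, indices, class_all_)

# Alternative (only equivalent when indices == arange(k)): append zero rows with cat.
padding = torch.zeros(
    (n_max - class_all_.shape[0], *class_all_.shape[1:]),
    dtype=class_all_.dtype,
    device=class_all_.device,
)
class_out_cat = torch.cat([class_all_, padding], dim=0)

print(class_out.shape, torch.equal(class_out, class_out_cat))
# torch.Size([1000, 4]) True
```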
I get the following error, which I cannot resolve:
```
<class 'TypeError'>: unsupported operand type(s) for *: 'int' and 'SymbolicDim'
⬆️
<class 'torch.onnx._internal.exporter._errors.GraphConstructionError'>: Error when calling function 'TracedOnnxFunction(<function aten_index_put at 0x7575b2cee830>)' with args '[SymbolicTensor(name='zeros', type=Tensor(FLOAT), shape=Shape([1000, 4]), producer='node_Expand_267', index=0), [SymbolicTensor(name='arange_2', type=Tensor(INT64), shape=Shape([SymbolicDim(u0)]), producer='node_Range_279', index=0)], SymbolicTensor(name='detach_2', type=Tensor(FLOAT), shape=Shape([SymbolicDim(u0), 4]), producer='node_Identity_254', index=0)]' and kwargs '{}'
⬆️
<class 'torch.onnx._internal.exporter._errors.ConversionError'>: Error when translating node %index_put : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%zeros, [%arange_2], %detach_2), kwargs = {}). See the stack trace for more information.
```
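The failing node in the traceback is `aten.index_put`, called with the zero buffer, a single index tensor (`arange_2`) and the values. In eager mode that is the same operation as the `index_copy_` call, which fits with the error coming from that line. The sketch below uses made-up small sizes and assumes `indices` is `0..k-1`.

```python
import torch

# Hypothetical small sizes; indices assumed to be 0..k-1 as in the traceback.
n_max = 8
class_all_ = torch.arange(12, dtype=torch.float32).reshape(3, 4)
indices = torch.arange(class_all_.shape[0])

via_index_copy = torch.zeros(n_max, 4).index_copy_(0, indices, class_all_)
# Mirrors aten.index_put(self, [indices], values), the node named in the error.
via_index_put = torch.zeros(n_max, 4).index_put_((indices,), class_all_)

print(torch.equal(via_index_copy, via_index_put))  # True
```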
In this error, does the `*` stand for multiplication? The problem really only occurs in that index line.
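For reference, the first line of the error has the standard shape of Python's `TypeError` for the `*` (multiplication) operator. Python raises it when `int * x` is evaluated and `x`'s type defines no multiplication with an int. The class below is a hypothetical empty stand-in, not the exporter's real `SymbolicDim`; it only reproduces the message format. This suggests that the `aten_index_put` translation presumably multiplies a plain integer by a symbolic dimension somewhere.

```python
class SymbolicDim:
    """Hypothetical stand-in: defines neither __mul__ nor __rmul__."""


message = None
try:
    1000 * SymbolicDim()  # int.__mul__ returns NotImplemented, no __rmul__ fallback
except TypeError as exc:
    message = str(exc)

print(message)  # unsupported operand type(s) for *: 'int' and 'SymbolicDim'
```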