Hello everyone, I’m new to ONNX and I’m trying to convert a model where I need do some for-loop assignmens like the code below,
import torch
import torch.nn as nn
@torch.jit.script
def create_alignment_v2():
base_mat = torch.zeros(2, 2, 2)
for i in range(base_mat.size(0)):
base_mat[i][0][0] = 1
return base_mat
class ToyModule(nn.Module):
def __int__(self):
super().__init__()
def forward(self, duration_predictor_output):
alignment = create_alignment_v2()
# output = alignment @ x
return alignment
def test():
module = ToyModule()
module.eval()
x = torch.rand(2, 28, 384)
alignment = module(x)
torch.onnx.export(module, x, 'toy.onnx',
export_params=True,
opset_version=10,
do_constant_folding=True,
verbose=True,
input_names=['seq'],
output_names=['alignment'],
dynamic_axes={'seq': {0: 'batch', 1: 'sequence'},}
)
test()
And the error message is:
Traceback (most recent call last):
File "/mfs/fangzhiqiang/workspace/tts/FastSpeech/jit_script.py", line 83, in <module>
test()
File "/mfs/fangzhiqiang/workspace/tts/FastSpeech/jit_script.py", line 78, in test
dynamic_axes={'seq': {0: 'batch', 1: 'sequence'},}
File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/__init__.py", line 168, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/utils.py", line 69, in export
use_external_data_format=use_external_data_format)
File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/utils.py", line 488, in _export
fixed_batch_size=fixed_batch_size)
File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/utils.py", line 351, in _model_to_graph
fixed_batch_size=fixed_batch_size, params_dict=params_dict)
File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/utils.py", line 120, in _optimize_graph
torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
IndexError: vector::_M_range_check: __n (which is 2) >= this->size() (which is 2)
In fact, if I remove the assignment operation, the graph can be built succesfully. I wonder whether this is a bug and how to convert such model to ONNX. Thanks for any reply~