Error when exporting ScriptFunction to ONNX

Hello everyone, I’m new to ONNX and I’m trying to convert a model where I need to do some for-loop assignments like the code below,

import torch
import torch.nn as nn

@torch.jit.script
def create_alignment_v2():
    # Minimal repro for the export failure: build a constant (2, 2, 2) tensor
    # and write into it in-place inside a Python for-loop.
    base_mat = torch.zeros(2, 2, 2)

    for i in range(base_mat.size(0)):
        # NOTE(review): the chained in-place indexed assignment below appears
        # to be what trips `_jit_pass_onnx_prepare_inplace_ops_for_onnx` —
        # per the report, removing this line lets the graph build. Keep the
        # statement in exactly this form so the repro stays valid.
        base_mat[i][0][0] = 1

    # Returns the (2, 2, 2) tensor with entry [i, 0, 0] set to 1 for each i.
    return base_mat

class ToyModule(nn.Module):
    """Toy module whose forward pass ignores its input and returns the
    alignment tensor produced by ``create_alignment_v2``."""

    def __init__(self):
        # Bug fix: the original spelled this ``__int__`` (typo), so the
        # defined constructor was never invoked and ``nn.Module.__init__``
        # only ran by inheritance. ``__init__`` is the intended name; the
        # observable behavior is unchanged.
        super().__init__()

    def forward(self, duration_predictor_output):
        # ``duration_predictor_output`` is currently unused — the matmul with
        # the input is commented out below; only the alignment is returned.
        alignment = create_alignment_v2()
        # output = alignment @ x
        return alignment

def test():
    """Build the toy module, run it once, then attempt the ONNX export."""
    module = ToyModule()
    module.eval()

    x = torch.rand(2, 28, 384)
    alignment = module(x)  # eager forward works; export is what fails

    # Gather the export options in one place, then hand them to the exporter.
    export_options = dict(
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        verbose=True,
        input_names=['seq'],
        output_names=['alignment'],
        dynamic_axes={'seq': {0: 'batch', 1: 'sequence'}},
    )
    torch.onnx.export(module, x, 'toy.onnx', **export_options)

# Guard the entry point so importing this module does not trigger the export.
if __name__ == "__main__":
    test()

And the error message is:

Traceback (most recent call last):
  File "/mfs/fangzhiqiang/workspace/tts/FastSpeech/jit_script.py", line 83, in <module>
    test()
  File "/mfs/fangzhiqiang/workspace/tts/FastSpeech/jit_script.py", line 78, in test
    dynamic_axes={'seq': {0: 'batch', 1: 'sequence'},}
  File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/__init__.py", line 168, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/utils.py", line 69, in export
    use_external_data_format=use_external_data_format)
  File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/utils.py", line 488, in _export
    fixed_batch_size=fixed_batch_size)
  File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/utils.py", line 351, in _model_to_graph
    fixed_batch_size=fixed_batch_size, params_dict=params_dict)
  File "/home/fangzhiqiang/miniconda3/envs/tts/lib/python3.6/site-packages/torch/onnx/utils.py", line 120, in _optimize_graph
    torch._C._jit_pass_onnx_prepare_inplace_ops_for_onnx(graph)
IndexError: vector::_M_range_check: __n (which is 2) >= this->size() (which is 2)

In fact, if I remove the assignment operation, the graph can be built successfully. I wonder whether this is a bug and how to convert such a model to ONNX. Thanks for any reply~