Error "tuple index with non-constant index" when exporting a model to ONNX format

I’m trying to convert a model to ONNX format. One function is wrapped with torch.jit.script() inside the model's forward method; torch.jit.script() is not used anywhere else. The forward method is as follows:

    def forward(self, img):
        img = self.onnx_trans(img)
        x = self.backbone(img)
        x = self.neck(x)
        mask_feat_pred = self.mask_feat_head(x[self.mask_feat_head.start_level:self.mask_feat_head.end_level + 1])
        cate_preds, kernel_preds = self.bbox_head(x)
        get_seg = torch.jit.script(get_seg_scripted)
        seg_result = get_seg(cate_preds, kernel_preds, mask_feat_pred)
        print(seg_result)
        return seg_result

And I use the following code to export the model:

    cfg = Custom_light_res50(mode='detect')
    cfg.print_cfg()
    model = SOLOv2(cfg).cuda()
    model.load_state_dict(torch.load(cfg.val_weight), strict=True)
    model.eval()

    input_size = (512, 512)
    input_img = cv2.imread('detect_imgs/test1.bmp', cv2.IMREAD_COLOR)
    input_img = cv2.resize(input_img, input_size)
    input_img = torch.from_numpy(input_img).cuda()

    torch.onnx.export(model,
                      input_img,
                      'seg.onnx',
                      input_names=['seg'],
                      output_names=['output'],
                      verbose=False,
                      opset_version=14)

The printed value of seg_result is correct, but I still get this error:

    Traceback (most recent call last):
      File "/home/feiyu/SOLOv2_minimal/ttt.py", line 99, in <module>
        torch.onnx.export(model,
      File "/home/feiyu/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 504, in export
        _export(
      File "/home/feiyu/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1529, in _export
        graph, params_dict, torch_out = _model_to_graph(
      File "/home/feiyu/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1115, in _model_to_graph
        graph = _optimize_graph(
      File "/home/feiyu/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 582, in _optimize_graph
        _C._jit_pass_lower_all_tuples(graph)
    RuntimeError: tuple index with non-constant index

Did I do anything wrong? Since I already get the correct value, the computation itself should be correct, so why is there still an error?

Hi, did you find a solution to this error?
I'm running into the same issue.