Does ONNX export fix the input/output shapes of a torch.jit.script function?

import torch

@torch.jit.script
def check_init(input_data, hidden_size, prev_state):
    # type: (torch.Tensor, int, torch.Tensor) -> torch.Tensor
    batch_size = input_data.size(0)
    spatial_size_0 = input_data.size(2)
    spatial_size_1 = input_data.size(3)
    # generate an empty prev_state if none was provided
    state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
    if prev_state.size(0) == 0:
        # no previous state yet: start from zeros on the same device as the input
        state = torch.zeros(state_size, device=input_data.device)
    else:
        # reuse the state from the previous call, reshaped to the expected size
        state = prev_state.view(state_size)
    return state

I am trying to export my model to ONNX. I have a function that checks whether the previous state has been initialized and, if not, initializes it based on the input size. Because it contains an if statement, I decorated it with @torch.jit.script. prev_state is a 5-dimensional tensor; on the first run I pass an empty tensor created with torch.tensor([]).view(0, 0, 0, 0, 0), and on all subsequent runs I pass back the tensor returned by the previous call. This check_init function is used in three different places in the network, and its input_data argument is the output of one of the stages of the neural net. I have set dynamic_axes on the input and output of the full neural net.
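For reference, the export call looks roughly like this (the model class, dummy shapes, and input/output names below are placeholders, not my exact ones):

import torch

# Sketch of the export setup; MyNet, the dummy shapes, and the names are placeholders.
model = MyNet()
dummy_input = torch.randn(1, 3, 368, 480)
dummy_state = torch.tensor([]).view(0, 0, 0, 0, 0)  # empty prev_state for the first call

torch.onnx.export(
    model,
    (dummy_input, dummy_state),
    "model.onnx",
    input_names=["input", "prev_state"],
    output_names=["output", "state"],
    dynamic_axes={
        "input": {0: "batch", 2: "height", 3: "width"},
        "output": {0: "batch", 2: "height", 3: "width"},
    },
    opset_version=11,
)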

The model exports to a .onnx file without errors. However, when I run it with an input, I get the following error on the second iteration (the first iteration, where prev_state has shape (0, 0, 0, 0, 0), works fine):

2020-06-04 13:32:14.289010608 [E:onnxruntime:, sequential_executor.cc:281 Execute] Non-zero status code returned while running Identity node. Name:'Identity_29' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:66 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, const onnxruntime::TensorShape*, OrtValue*&, size_t) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{0,0,0,0,0} Requested shape:{2,1,64,92,120}

2020-06-04 13:32:14.289043204 [E:onnxruntime:, sequential_executor.cc:281 Execute] Non-zero status code returned while running If node. Name:'If_21' Status Message: Non-zero status code returned while running Identity node. Name:'Identity_29' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:66 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, const onnxruntime::TensorShape*, OrtValue*&, size_t) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{0,0,0,0,0} Requested shape:{2,1,64,92,120}

Traceback (most recent call last):
  File "run_onnx.py", line 108, in <module>
    runner.update(input_tensor, last_timestamp)
  File "/home/test/image_onnx.py", line 103, in update
    onnx_out = self.onnx_session.run(None, onnx_input)
  File "/home/test/lib/python3.8/site-packages/onnxruntime/capi/session.py", line 111, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running If node. Name:'If_21' Status Message: Non-zero status code returned while running Identity node. Name:'Identity_29' Status Message: /onnxruntime_src/onnxruntime/core/framework/execution_frame.cc:66 onnxruntime::common::Status onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(int, const onnxruntime::TensorShape*, OrtValue*&, size_t) shape && tensor.Shape() == *shape was false. OrtValue shape verification failed. Current shape:{0,0,0,0,0} Requested shape:{2,1,64,92,120}
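For context, the loop that triggers this is roughly the following (the input/output names and shapes are placeholders for what run_onnx.py and image_onnx.py actually do):

import numpy as np
import onnxruntime as ort

# Sketch of the iteration loop; names and shapes are placeholders.
sess = ort.InferenceSession("model.onnx")
prev_state = np.zeros((0, 0, 0, 0, 0), dtype=np.float32)  # empty state for the first run
frames = [np.random.rand(1, 3, 368, 480).astype(np.float32) for _ in range(3)]  # dummy inputs

for frame in frames:
    # feed the state returned by the previous run back into the model
    output, prev_state = sess.run(None, {"input": frame, "prev_state": prev_state})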

Shouldn’t the exporter propagate the dynamic_axes property to all subsequent stages of the network and realize that those variables will also have dynamic sizes? Any ideas on how to solve this problem?

Exporting with just TorchScript works fine, but ONNX seems to be less flexible.
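By “just TorchScript” I mean something like the following (again, MyNet and the shapes are placeholders, and I’m assuming the model returns an (output, state) pair):

import torch

# Sketch: scripting and reloading the model directly, no ONNX involved.
scripted = torch.jit.script(MyNet())
scripted.save("model.pt")

loaded = torch.jit.load("model.pt")
state = torch.tensor([]).view(0, 0, 0, 0, 0)
for _ in range(3):
    out, state = loaded(torch.randn(1, 3, 368, 480), state)  # works across iterations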

Hi @Andreas_Georgiou,

This error occurs within ONNX Runtime, so it’s likely best to report an issue there first and then work backwards up the stack. It’s not clear whether the problem is in the PyTorch ONNX exporter, or whether the exporter is emitting a valid ONNX model and the failure is in ONNX Runtime’s own shape analysis.
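One way to start narrowing it down (a rough sketch, assuming the file is saved as model.onnx and using the node name from your error message): check whether the exported model passes the ONNX checker, and print the If_21 node to see what the exporter recorded in its branches.

import onnx

# Rough sketch: validate the exported graph and inspect the If node's subgraphs.
# "model.onnx" and the node name "If_21" are assumptions taken from the error message.
model = onnx.load("model.onnx")
onnx.checker.check_model(model)  # raises if the exported model itself is malformed

for node in model.graph.node:
    if node.op_type == "If" and node.name == "If_21":
        print(node)  # the then/else branches are stored as graph attributes on the node

If the checker passes and the branch outputs carry a fixed shape, the shape was most likely baked in at export time, which would point back at the exporter; if the checker fails, the exported model itself is invalid and that is worth reporting against PyTorch.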