Tutorial for torch.onnx fails

I used the following code to reproduce the tutorial example avoid-numpy-and-built-in-python-types:

import torch
from torch import nn
import onnxruntime as ort


class VarModel(nn.Module):
    def forward(self, x, y):
        return x.reshape(y, -1)


var_model = VarModel()
a = torch.arange(6)
b = torch.tensor([3])
c = torch.tensor([2])
var_model(a, b), var_model(a, c)

torch.onnx.export(var_model, (a, b), "var_op.onnx")

ort_session = ort.InferenceSession("var_op.onnx")
out1 = ort_session.run(None, {"0": a.numpy(), "1": b.numpy()})
out2 = ort_session.run(None, {"0": a.numpy(), "1": c.numpy()})
print(out1)
print(out2)

But I got this error:

Fail                                      Traceback (most recent call last)
/var/folders/3p/mvrxlkq16dlfhgxbl9f31bvh0000gn/T/ipykernel_1802/3457345533.py in <module>
----> 1 ort_session = ort.InferenceSession("var_op.onnx")
      2 out1 = ort_session.run(None, {"0": a.numpy(), "1": b.numpy()})
      3 out2 = ort_session.run(None, {"0": a.numpy(), "1": c.numpy()})
      4 out1, out2

~/miniconda3/envs/py37/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py in __init__(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)
    358 
    359         try:
--> 360             self._create_inference_session(providers, provider_options, disabled_optimizers)
    361         except ValueError:
    362             if self._enable_fallback:

~/miniconda3/envs/py37/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py in _create_inference_session(self, providers, provider_options, disabled_optimizers)
    395         session_options = self._sess_options if self._sess_options else C.get_default_session_options()
    396         if self._model_path:
--> 397             sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
    398         else:
    399             sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)

Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from var_op.onnx failed:Node (Concat_1) Op (Concat) [ShapeInferenceError] All inputs to Concat must have same rank

The versions are:

  • torch 1.8.1
  • onnx 1.13.1
  • onnxruntime 1.14.1

Could you update PyTorch to the latest stable or nightly release and check whether you still run into the error?

Updating to torch 1.13.1 solved the problem. I found that the ONNX model exported by 1.8.1 differs from the one exported by 1.13.1: the 1.8.1 graph contains an additional Unsqueeze op.
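For reference, a minimal sketch of how to compare the two exports by listing the node op types (the file names var_op_181.onnx and var_op_1131.onnx are placeholders for the files produced by the two PyTorch versions):

import onnx

# Print the op types of every node in each exported graph; the extra
# Unsqueeze in the 1.8.1 export shows up in its node list.
for path in ("var_op_181.onnx", "var_op_1131.onnx"):
    graph = onnx.load(path).graph
    print(path, [node.op_type for node in graph.node])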

It seems that if the Unsqueeze op is removed with the onnx helper utilities, the model produced by 1.8.1 also loads and runs correctly.
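A rough sketch of that patch, assuming the 1.8.1 export contains exactly one Unsqueeze node and that bypassing it is valid for this particular graph (both are assumptions about this model, not something torch.onnx guarantees):

import onnx

model = onnx.load("var_op.onnx")  # the file exported by torch 1.8.1
graph = model.graph

# Find the extra Unsqueeze node (assumes there is exactly one).
unsqueeze = next(node for node in graph.node if node.op_type == "Unsqueeze")

# Rewire every consumer of the Unsqueeze output to read its input directly.
for node in graph.node:
    for i, name in enumerate(node.input):
        if name == unsqueeze.output[0]:
            node.input[i] = unsqueeze.input[0]

# Drop the node, validate, and save the patched model.
graph.node.remove(unsqueeze)
onnx.checker.check_model(model)
onnx.save(model, "var_op_patched.onnx")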