dynamic_axes not working as expected in torch.onnx.export (torchvision)

Hi all,
I am not sure what is going on, but dynamic_axes does not work as expected when I reload the exported model from ONNX and run inference: I am forced to feed exactly the same input shape that was used at export time, which should not be necessary for a dimension declared as dynamic.
Here is what happens:

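import os

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

# Example input used for tracing: a batch of 1 RGB image, 300x400
x = torch.rand(1, 3, 300, 400)
predictions = model(x)  # sanity check in PyTorch

os.makedirs("./output", exist_ok=True)  # torch.onnx.export does not create the directory
torch.onnx.export(model,
                  x,
                  "./output/faster_rcnn.onnx",
                  export_params=True,
                  opset_version=11,
                  input_names=['input'],
                  output_names=['output'],  # Faster R-CNN returns boxes, labels, scores; only the first is named here
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})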


import onnxruntime as ort
import numpy as np

ort_session = ort.InferenceSession("./output/faster_rcnn.onnx")

def to_numpy(tensor):
    # Detach first if the tensor is part of an autograd graph
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 10 images this time, same height and width as the export example
y = torch.rand(10, 3, 300, 400)

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(y)}
ort_outs = ort_session.run(None, ort_inputs)
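Before running, it may help to check what the exported graph actually declares. In onnxruntime, `ort_session.get_inputs()[0].shape` lists fixed dimensions as integers and symbolic ones as strings (here presumably `'batch_size'`). Below is a minimal sketch of a helper, `dynamic_dims` (hypothetical, not part of onnxruntime), that reports which dimensions are not fixed. Keep in mind that a symbolic input dimension only relabels the graph input; it does not by itself guarantee that every operator inside the graph accepts other sizes.

```python
from typing import List, Optional, Sequence, Union

# Each dimension is an int (fixed size) or something else (symbolic name or unknown)
Dim = Optional[Union[int, str]]


def dynamic_dims(shape: Sequence[Dim]) -> List[int]:
    """Return the indices of dimensions that are not fixed integers."""
    return [i for i, d in enumerate(shape) if not isinstance(d, int)]


# With the session from the post one would call:
#   dynamic_dims(ort_session.get_inputs()[0].shape)
# Here a literal list stands in for that value (an assumed example, not captured output).
declared = ['batch_size', 3, 300, 400]
print(dynamic_dims(declared))  # [0]
```

If this prints `[0]` for the exported model, the input really is declared with a dynamic batch, and the failure has to come from inside the graph.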

This throws:

---------------------------------------------------------------------------
Fail                                      Traceback (most recent call last)
<ipython-input-9-ded63d7ba257> in <module>
     11 # compute ONNX Runtime output prediction
     12 ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(y)}
---> 13 ort_outs = ort_session.run(None, ort_inputs)
     14 ort_outs

/anaconda/envs/fra_vrr/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py in run(self, output_names, input_feed, run_options)
    122             output_names = [output.name for output in self._outputs_meta]
    123         try:
--> 124             return self._sess.run(output_names, input_feed, run_options)
    125         except C.EPFail as err:
    126             if self._enable_fallback:

Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Split node. Name:'Split_8' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={10,3,300,400} NumOutputs=1 Num entries in 'split' (must equal number of outputs) was 1 Sum of sizes in 'split' (must equal size of selected axis) was 1
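The message states the constraint ONNX Runtime enforces on a `Split` node whose `split` attribute is fixed: the sizes listed in `split` must add up to the size of the selected axis. Assuming (as the message suggests, though not verified in the graph itself) that tracing with batch size 1 stored `split=[1]`, the numpy sketch below mimics that size check. `onnx_style_split` is a hypothetical helper, not an ONNX Runtime API.

```python
import numpy as np


def onnx_style_split(x, split, axis=0):
    """Split x into chunks of the given sizes, rejecting sizes that do not cover the axis."""
    if sum(split) != x.shape[axis]:
        raise ValueError(
            f"Sum of sizes in 'split' ({sum(split)}) must equal size of axis {axis} ({x.shape[axis]})"
        )
    # np.split expects cut positions, not sizes, so convert sizes to cumulative offsets
    cut_points = np.cumsum(split)[:-1]
    return np.split(x, cut_points, axis=axis)


baked_split = [1]  # hypothetical: the value recorded when tracing with batch size 1

parts = onnx_style_split(np.zeros((1, 3, 300, 400), dtype=np.float32), baked_split)
print(len(parts), parts[0].shape)  # 1 (1, 3, 300, 400)

try:
    onnx_style_split(np.zeros((10, 3, 300, 400), dtype=np.float32), baked_split)
except ValueError as err:
    print(err)  # Sum of sizes in 'split' (1) must equal size of axis 0 (10)
```

This reproduces the shape of the failure: the batch-1 input passes, while the batch-10 input fails the same sum check that ONNX Runtime reports for `Split_8`.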

To be clear, everything works fine when y = torch.rand(1, 3, 300, 400).
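Since batch size 1 works, one possible stopgap (an assumption on my side, not a fix for the export itself) is to keep the exported model as-is and feed the session one image at a time. `run_per_image` is a hypothetical helper:

```python
from typing import Any, Callable, List

import numpy as np


def run_per_image(run_one: Callable[[np.ndarray], Any], batch: np.ndarray) -> List[Any]:
    """Call run_one on each image of an (N, C, H, W) batch, one at a time."""
    # Slicing with i:i + 1 (instead of i) keeps a leading batch dimension of 1,
    # i.e. the (1, 3, H, W) layout of the example input used at export time.
    return [run_one(batch[i:i + 1]) for i in range(batch.shape[0])]


# Usage with the session from the post (not executed here):
#   input_name = ort_session.get_inputs()[0].name
#   outs = run_per_image(lambda img: ort_session.run(None, {input_name: img}), to_numpy(y))
```

This gives up batched execution, so it only sidesteps the problem until the batch dimension is genuinely dynamic inside the graph.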
Is this a torchvision issue?
Thanks!