Dynamic_axes doesn't work for torch.onnx.export() when torch.cat is present?

I have a network (nn.Module) that takes two arguments: the first is a tensor, the second is a list of variable length. I'm able to dump the network to ONNX with torch.onnx.export(), but the dynamic_axes argument doesn't seem to take effect.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ActorNetwork(nn.Module):
    # ActorNetwork passes its unit test; the conv/fc layer definitions
    # (bitrateFc, bufferFc, tConv1d, dConv1d, cConv1d, leftChunkFc, ...) are omitted from this snippet
    def __init__(self, state_dim, action_dim, n_conv=128, n_fc=128, n_fc1=128):
        super(ActorNetwork, self).__init__()
        self.numFcInput = 2 * self.vectorOutDim * (self.s_dim[1] - 4 + 2) + 3 * self.scalarOutDim
        self.fullyConnected = nn.Linear(self.numFcInput, self.numFcOutput)

        # Increase the penultimate layer width by one (for the candidate bitrate)
        self.outputLayer = nn.Linear(self.numFcOutput + 1, 1)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
    def forward_(self, inputs, candidate_bitrate):
        bitrateFcOut=F.relu(self.bitrateFc(inputs[:,0:1,-1]),inplace=True)
        # print(inputs[:,0:1,-1].shape,bitrateFcOut.shape)
        # torch.Size([1, 1]) torch.Size([1, 128])
        # torch.Size([47, 1]) torch.Size([47, 128])
        bufferFcOut=F.relu(self.bufferFc(inputs[:,1:2,-1]),inplace=True)
 
        tConv1dOut=F.relu(self.tConv1d(inputs[:,2:3,:]),inplace=True)

        dConv1dOut=F.relu(self.dConv1d(inputs[:,3:4,:]),inplace=True)

        cConv1dOut=F.relu(self.cConv1d(inputs[:,4:5,:]),inplace=True)
        
        # print(cConv1dOut.shape) #[1, 128, 3]
        leftChunkFcOut=F.relu(self.leftChunkFc(inputs[:,5:6,-1]),inplace=True)

        t_flatten=tConv1dOut.view(tConv1dOut.shape[0],-1)

        d_flatten=dConv1dOut.view(dConv1dOut.shape[0],-1)

        c_flatten=cConv1dOut.view(cConv1dOut.shape[0],-1)

        fullyConnectedInput=torch.cat([bitrateFcOut,bufferFcOut,t_flatten,d_flatten, c_flatten,leftChunkFcOut],1)
        
        fcOutput=F.relu(self.fullyConnected(fullyConnectedInput),inplace=True)
        # Add dynamic bitrate levels
        # fcOutput.shape[0]==1
        # print(fcOutput.shape) # (1,128)
        out = torch.cat([fcOutput, candidate_bitrate * torch.ones(fcOutput.shape[0], 1, device=self.device)], 1)
        
        out = self.outputLayer(out)
        
        return out
    
    def forward(self, inputs, available_bitrates):
        # Enable dynamic bitrate levels: policy nn evaluates each available bitrate; convert from kbps to mbps
        if hasattr(available_bitrates,'to_list'): available_bitrates = available_bitrates.tolist()[0,0]
        print('available_bitrates: ',available_bitrates[0])
        out = torch.cat([self.forward_(inputs.to(self.device),i/1000) for i in available_bitrates[0]],1)
        out = torch.softmax(out, dim=-1)

        print(out)
        return out

For example, torch.onnx.export() does produce the ONNX file, but reloading it with onnxruntime and feeding a second input of a different length still gives an error saying the size of the second argument doesn't fit:


net = ActorNetwork()
input_names = ["input_state", "input_bit_rates"]
output_names = ["output"]
dummy_input = (state, torch.zeros(1, 50))  # state is defined elsewhere in the script

seq = 6
dynamic_axes = {'input_state': [0], 'input_bit_rates': [0, 1], 'output': [0, 1]}
torch.onnx.export(net,
                  args=dummy_input,
                  f="results/tmp.onnx",
                  verbose=False,
                  input_names=input_names,
                  output_names=output_names,
                  export_params=True,
                  dynamic_axes=dynamic_axes)
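
To check whether the axes were actually recorded as dynamic, the declared input shapes in the exported graph can be printed with the onnx package (a small sketch, not specific to this model; dynamic dims show up as named dim_param strings, fixed dims as dim_value integers):

import onnx

model = onnx.load("results/tmp.onnx")
for inp in model.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)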



import onnxruntime as rt
sess = rt.InferenceSession('results/tmp.onnx')
# input_names = [ "input_state", "input_bit_rates" ]


dummy_input = (state, None)

sess.run(['output'], {'input_state': state.numpy(),
                      'input_bit_rates': [[0] * 10]})

The error I got was:

2022-04-19 03:14:54.639048822 [E:onnxruntime:, sequential_executor.cc:364 Execute] Non-zero status code returned while running Split node. Name:'Split_2' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={10} NumOutputs=50 Num entries in 'split' (must equal number of outputs) was 50 Sum of sizes in 'split' (must equal size of selected axis) was 50
Traceback (most recent call last):
  File "tmp.py", line 149, in <module>
    .numpy(), "input_bit_rates": [[0]*10] })
  File "/home/ec2-user/anaconda3/envs/mxnet_latest_p37/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 192, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Split node. Name:'Split_2' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={10} NumOutputs=50 Num entries in 'split' (must equal number of outputs) was 50 Sum of sizes in 'split' (must equal size of selected axis) was 50
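
My guess, from NumOutputs=50 matching the length of the dummy second input, is that the list comprehension in forward() gets unrolled at trace time, so the exported graph contains a Split with a fixed number of outputs. A toy module (hypothetical, just to isolate the loop-plus-torch.cat pattern) that I would expect to hit the same issue:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x, rates):
        # Iterating over a tensor is unrolled during tracing, so the graph
        # should end up with a Split whose output count equals rates.shape[1]
        outs = [x * r for r in rates[0]]
        return torch.cat(outs, dim=1)

torch.onnx.export(Toy(),
                  args=(torch.randn(1, 4), torch.zeros(1, 50)),
                  f="toy.onnx",
                  input_names=["input_state", "input_bit_rates"],
                  output_names=["output"],
                  dynamic_axes={'input_bit_rates': [0, 1], 'output': [0, 1]})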

Does this mean that dynamic_axes won't work when torch.cat() is present? Is there a workaround?
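
One workaround I've been considering (an untested sketch; self.features below is a hypothetical method that would hold everything forward_ currently does up to fcOutput) is to drop the Python loop and evaluate all candidate bitrates in one batched pass, so there is nothing for the tracer to unroll:

def forward(self, inputs, available_bitrates):
    # Untested sketch: batch the candidate bitrates instead of looping in Python
    feats = self.features(inputs.to(self.device))                       # (1, numFcOutput)
    rates = (available_bitrates[0] / 1000).to(self.device).view(-1, 1)  # (K, 1), kbps -> Mbps
    feats = feats.expand(rates.shape[0], -1)                            # (K, numFcOutput)
    out = self.outputLayer(torch.cat([feats, rates], dim=1))            # (K, 1)
    return torch.softmax(out.transpose(0, 1), dim=-1)                   # (1, K)

But I'm not sure whether that's the right way to get a truly dynamic second input, hence the question. Thanks.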