I have a neural network that takes two arguments: the first is a tensor, the second a list of variable length. I'm able to dump the network to ONNX with torch.onnx.export(); however, the dynamic_axes argument doesn't work.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ActorNetwork(nn.Module):
    # actornetwork pass the test
    def __init__(self, state_dim, action_dim, n_conv=128, n_fc=128, n_fc1=128):
        super(ActorNetwork, self).__init__()
        # (definitions of self.s_dim, self.vectorOutDim, self.scalarOutDim,
        #  self.numFcOutput and of the layers self.bitrateFc, self.bufferFc,
        #  self.tConv1d, self.dConv1d, self.cConv1d, self.leftChunkFc are
        #  elided from this snippet)
        self.numFcInput = 2 * self.vectorOutDim * (self.s_dim[1] - 4 + 2) + 3 * self.scalarOutDim
        self.fullyConnected = nn.Linear(self.numFcInput, self.numFcOutput)
        # Increase the penultimate layer by one (for the available bitrate)
        self.outputLayer = nn.Linear(self.numFcOutput + 1, 1)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward_(self, inputs, candidate_bitrate):
        bitrateFcOut = F.relu(self.bitrateFc(inputs[:, 0:1, -1]), inplace=True)
        # print(inputs[:, 0:1, -1].shape, bitrateFcOut.shape)
        # torch.Size([1, 1]) torch.Size([1, 128])
        # torch.Size([47, 1]) torch.Size([47, 128])
        bufferFcOut = F.relu(self.bufferFc(inputs[:, 1:2, -1]), inplace=True)
        tConv1dOut = F.relu(self.tConv1d(inputs[:, 2:3, :]), inplace=True)
        dConv1dOut = F.relu(self.dConv1d(inputs[:, 3:4, :]), inplace=True)
        cConv1dOut = F.relu(self.cConv1d(inputs[:, 4:5, :]), inplace=True)
        # print(cConv1dOut.shape)  # [1, 128, 3]
        leftChunkFcOut = F.relu(self.leftChunkFc(inputs[:, 5:6, -1]), inplace=True)
        t_flatten = tConv1dOut.view(tConv1dOut.shape[0], -1)
        d_flatten = dConv1dOut.view(dConv1dOut.shape[0], -1)
        c_flatten = cConv1dOut.view(cConv1dOut.shape[0], -1)
        fullyConnectedInput = torch.cat(
            [bitrateFcOut, bufferFcOut, t_flatten, d_flatten, c_flatten, leftChunkFcOut], 1)
        fcOutput = F.relu(self.fullyConnected(fullyConnectedInput), inplace=True)
        # Add dynamic bitrate levels
        # fcOutput.shape[0] == 1
        # print(fcOutput.shape)  # (1, 128)
        out = torch.cat(
            [fcOutput, torch.ones(fcOutput.shape[0], 1, device=self.device) * candidate_bitrate], 1)
        out = self.outputLayer(out)
        return out

    def forward(self, inputs, available_bitrates):
        # Enable dynamic bitrate levels: the policy nn evaluates each available
        # bitrate; convert from kbps to mbps
        if hasattr(available_bitrates, 'tolist'):  # a tensor was passed in
            available_bitrates = available_bitrates.tolist()
        print('available_bitrates: ', available_bitrates[0])
        out = torch.cat(
            [self.forward_(inputs.to(self.device), i / 1000) for i in available_bitrates[0]], 1)
        out = torch.softmax(out, dim=-1)
        print(out)
        return out
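As far as I can tell, torch.onnx.export() traces the model, so iterating over a tensor in Python (the list comprehension in forward()) is recorded as a Split/unbind with however many outputs the dummy input had at export time. A toy module (unrelated to my network, just to illustrate the behaviour I'm seeing) reproduces it:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def forward(self, x, xs):
            # Iterating a tensor is traced as a Split with a FIXED number
            # of outputs, namely xs.shape[1] at export time.
            return torch.cat([x * i for i in xs[0]], dim=1)

    torch.onnx.export(Toy(), (torch.ones(1, 2), torch.zeros(1, 5)), "toy.onnx",
                      input_names=["x", "xs"], dynamic_axes={"xs": {1: "n"}})
    # Feeding an xs with a length other than 5 then fails at runtime with the
    # same kind of Split error, regardless of dynamic_axes.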
For example, even though torch.onnx.export() dumps the ONNX file without complaint, reloading it with onnxruntime still gives an error saying the size of the second argument doesn't fit:
net = ActorNetwork(state_dim, action_dim)
input_names = ["input_state", "input_bit_rates"]
output_names = ["output"]
dummy_input = (state, torch.zeros(1, 50))
seq = 6
dynamic_axes = {'input_state': [0], 'input_bit_rates': [0, 1], 'output': [0, 1]}
torch.onnx.export(net,
                  args=dummy_input,
                  f="results/tmp.onnx",
                  verbose=False,
                  input_names=input_names,
                  output_names=output_names,
                  export_params=True,
                  dynamic_axes=dynamic_axes)
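As far as I understand, the dict form of dynamic_axes should be equivalent to the list form I used; for reference, this is what I mean (the axis names 'batch' and 'n_bitrates' are just labels I made up):

    # Dict form of dynamic_axes; the axis names are arbitrary labels.
    dynamic_axes = {
        'input_state':     {0: 'batch'},
        'input_bit_rates': {0: 'batch', 1: 'n_bitrates'},
        'output':          {0: 'batch', 1: 'n_bitrates'},
    }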
import onnxruntime as rt
sess = rt.InferenceSession('results/tmp.onnx')
# input_names = ["input_state", "input_bit_rates"]
dummy_input = (state, None)
sess.run(['output'], {'input_state': state.numpy(), "input_bit_rates": [[0] * 10]})
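To check which axes actually came out dynamic in the exported graph, the session's input/output metadata can be inspected (this uses onnxruntime's standard get_inputs()/get_outputs(); fixed axes show up as ints, dynamic ones as symbolic names or None):

    # Fixed axes print as integers, dynamic axes as strings/None.
    for inp in sess.get_inputs():
        print(inp.name, inp.shape)
    for outp in sess.get_outputs():
        print(outp.name, outp.shape)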
The error I got was:
2022-04-19 03:14:54.639048822 [E:onnxruntime:, sequential_executor.cc:364 Execute] Non-zero status code returned while running Split node. Name:'Split_2' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={10} NumOutputs=50 Num entries in 'split' (must equal number of outputs) was 50 Sum of sizes in 'split' (must equal size of selected axis) was 50
Traceback (most recent call last):
File "tmp.py", line 149, in <module>
.numpy(), "input_bit_rates": [[0]*10] })
File "/home/ec2-user/anaconda3/envs/mxnet_latest_p37/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 192, in run
return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Split node. Name:'Split_2' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={10} NumOutputs=50 Num entries in 'split' (must equal number of outputs) was 50 Sum of sizes in 'split' (must equal size of selected axis) was 50
Does this mean that dynamic_axes won’t work when torch.cat() is present? Any workaround? Thanks.
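Update: one workaround I'm considering is to vectorize the per-bitrate evaluation with broadcasting, so there is no Python loop for the tracer to unroll. A rough, untested sketch (backbone_ is a hypothetical helper that would run forward_ up to fcOutput):

    # Sketch only: evaluate all candidate bitrates in one batched pass.
    # backbone_ is hypothetical -- it runs forward_ up to fcOutput.
    def forward(self, inputs, rates):                         # rates: (B, N) in kbps
        fcOutput = self.backbone_(inputs.to(self.device))     # (B, numFcOutput)
        B, N = fcOutput.shape[0], rates.shape[1]
        feat = fcOutput.unsqueeze(1).expand(B, N, -1)         # (B, N, numFcOutput)
        r = (rates.to(self.device) / 1000).unsqueeze(-1)      # kbps -> mbps, (B, N, 1)
        out = self.outputLayer(torch.cat([feat, r], dim=-1))  # (B, N, 1)
        return torch.softmax(out.squeeze(-1), dim=-1)         # (B, N)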