I simplified my complex PyTorch model to the toy example below:
```python
import torch
from torch import nn
import onnx
import onnxruntime
import numpy as np


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.template = torch.randn((1000, 1000))

    def forward(self, points):
        template = self.template
        points = points.reshape(-1, 2)
        heatmaps = [template[point[0]:point[0] + 10, point[1]:point[1] + 20] for point in points]
        return heatmaps


model = Model()
points = torch.randint(100, 200, (1, 8, 2))
torch.onnx.export(model, args=points, f='toy.onnx',
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  verbose=False,
                  input_names=['input1'],
                  output_names=['output1'],
                  dynamic_axes={'input1': {0: 'batch_size'},
                                'output1': {0: 'batch_size'},
                                }
                  )

session = onnxruntime.InferenceSession("./toy.onnx")
inputs = np.random.randint(100, 200, (2, 8, 2))
ort_inputs = {'input1': inputs}
ort_outs = session.run(None, ort_inputs)
```
The model exports to ONNX successfully, but when I run it with a different input batch size I get this error:
```
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running Split node. Name:'Split_3' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={16,2} NumOutputs=8 Num entries in 'split' (must equal number of outputs) was 8 Sum of sizes in 'split' (must equal size of selected axis) was 8
```
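The numbers in the message fit together as follows. The export was traced with `points` of shape `(1, 8, 2)`, which reshapes to `(8, 2)`. The unrolled comprehension therefore became a Split into 8 pieces of size 1. At runtime, a `(2, 8, 2)` input reshapes to `(16, 2)`, and 8 pieces of size 1 no longer cover the 16 rows. The sketch that follows only mimics that check in NumPy. `traced_split_sizes` and `split_is_valid` are hypothetical helpers, not ONNX or onnxruntime APIs.

```python
import numpy as np


def traced_split_sizes(example_points):
    """Hypothetical helper: the 'split' attribute frozen at export time.

    Tracing runs the list comprehension once on the example input, so the
    exported graph keeps one size-1 slice per point present at export time.
    """
    n = example_points.reshape(-1, 2).shape[0]
    return [1] * n


def split_is_valid(split_sizes, runtime_points):
    """Hypothetical helper mimicking the runtime check quoted in the error:
    the split sizes must sum to the length of axis 0."""
    axis_len = runtime_points.reshape(-1, 2).shape[0]
    return sum(split_sizes) == axis_len


export_points = np.random.randint(100, 200, (1, 8, 2))
sizes = traced_split_sizes(export_points)  # eight entries, each 1

print(split_is_valid(sizes, np.random.randint(100, 200, (1, 8, 2))))  # True
print(split_is_valid(sizes, np.random.randint(100, 200, (2, 8, 2))))  # False: 8 != 16
```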
I know this error is caused by looping over a dynamically sized tensor (the loop gets unrolled at export time for the 8 points in the example input), but I don't know how to solve it. Note that, due to some restrictions, I cannot move this operation out of the model.
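One possible workaround is to replace the Python loop with a single advanced-indexing gather. That way the number of patches comes from the input shape instead of from an unrolled loop. The sketch below is written in NumPy so it can be checked on its own, and `crop_patches` is a hypothetical name. PyTorch tensors support the same integer-array indexing, so the three indexing lines should carry over to `forward` by using `torch.arange(h, device=points.device)` in place of `np.arange(h)`. It is an assumption, not verified here, that `torch.onnx.export` turns this into Gather ops with a dynamic first dimension. The result is also a single stacked array of shape `(N, 10, 20)` rather than a list.

```python
import numpy as np


def crop_patches(template, points, h=10, w=20):
    """Hypothetical loop-free version of the list comprehension in Model.forward.

    points: integer array of shape (..., 2); each row is a (row, col) top-left corner.
    Returns an array of shape (N, h, w), where N is the number of points.
    """
    points = points.reshape(-1, 2)
    rows = points[:, 0:1] + np.arange(h)  # (N, h): row indices for each patch
    cols = points[:, 1:2] + np.arange(w)  # (N, w): column indices for each patch
    # Index arrays of shape (N, h, 1) and (N, 1, w) broadcast to (N, h, w),
    # so out[n, i, j] == template[rows[n, i], cols[n, j]].
    return template[rows[:, :, None], cols[:, None, :]]


template = np.random.randn(1000, 1000)
for batch in (1, 2):
    pts = np.random.randint(100, 200, (batch, 8, 2))
    out = crop_patches(template, pts)
    print(out.shape)  # (8, 10, 20), then (16, 10, 20)
```

Each output slice `out[n]` equals `template[p0:p0 + 10, p1:p1 + 20]` for the n-th point, so the values match the original loop. Only the container changes, from a list to one stacked array.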
Any suggestions would be appreciated. Thanks!