I am tuning my PyTorch code to output static-shaped ONNX IR. Since my target ONNX runtime does not support onnx::Shape, I'd like to export the IR with hard-coded shapes. Is there a formal way to do so?
In the two cases below, `f1` is ordinary PyTorch code and exports an ONNX IR containing a dynamic Shape operator, while `f2` exports the static graph I want. The trick in `f2` is casting the result of `size()` to a plain Python int, which feels somewhat hacky. Is there a better way, such as a flag for the tracer or a post-tracing optimization pass?
P.S. I've tried `a_sz.detach_()` and it doesn't change the output.
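For intuition on why the `int()` cast bakes the value in, here is a toy sketch (not PyTorch's actual tracer, just an analogy with a hypothetical `Proxy` class): the tracer records operations on symbolic handles, but `__int__` must return a real Python int, so the symbolic link is severed and the trace-time constant leaks into the graph.

```python
# Toy analogy of tracing, NOT PyTorch internals: shows why int() on a
# traced size produces a hard-coded constant instead of a Shape op.

class Proxy:
    """Stands in for a traced value such as a.size()[0]."""
    def __init__(self, name, concrete):
        self.name = name          # symbolic name recorded in the graph
        self.concrete = concrete  # value observed during the trace run
    def __int__(self):
        # int() is required to return a real Python int, so the
        # symbolic handle is lost and the trace-time value escapes.
        return self.concrete

def trace(f, concrete_size):
    """Run f once with a proxied size, collecting graph entries."""
    graph = []
    sz = Proxy("size(a, 0)", concrete_size)
    f(sz, graph)
    return graph

def f1(sz, graph):
    # Keeps the symbolic handle -> graph depends on the input's shape.
    graph.append(("ones", sz.name))

def f2(sz, graph):
    # int() yields the trace-time constant -> shape is hard-coded.
    graph.append(("ones", int(sz)))

print(trace(f1, 3))  # [('ones', 'size(a, 0)')]
print(trace(f2, 3))  # [('ones', 3)]
```

This mirrors the real behavior below: `f2`'s graph no longer references the input's shape at all, so the exporter can constant-fold the whole thing.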
```python
import torch

def test_jit_size():
    a = torch.tensor([1, 2, 3.])

    def f1(a):
        a_sz = a.size()[0]
        b = torch.ones(a_sz)
        return b

    def f2(a):
        a_sz = a.size()[0]
        a_sz = int(a_sz)
        b = torch.ones(a_sz)
        return b

    def get_onnx_graph(f, input):
        trace, z = torch.jit.get_trace_graph(f, input)
        torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
        return trace

    print("f1/f2")
    print(get_onnx_graph(f1, (a,)))
    print(get_onnx_graph(f2, (a,)))
```
Output (with PyTorch 1.0.0a0+bf99ffc):
```
graph(%0 : Float(3)) {
  %1 : Long() = onnx::Constant[value={0}]()
  %2 : Dynamic = onnx::Shape(%0)
  %3 : Long() = onnx::Gather[axis=0](%2, %1)
  %4 : Dynamic = onnx::Unsqueeze[axes=[0]](%3)
  %5 : int[] = onnx::Concat[axis=0](%4)
  %6 : Float(3) = onnx::ConstantFill[dtype=1, input_as_shape=1, value=1](%5)
  return (%6);
}

graph(%0 : Float(3)) {
  %1 : Float(3) = onnx::Constant[value= 1 1 1 [ CPUFloatType{3} ]]()
  return (%1);
}
```
Thanks!