Export ONNX with static shapes instead of the onnx::Shape operator

I am tuning my PyTorch code to output static-shaped ONNX IR. Since my target ONNX runtime does not support onnx::Shape, I’d like to export the IR with hard-coded shapes. Is there a formal way to do so?

In the two cases below, f1 is the plain PyTorch code, and it produces an ONNX IR with a dynamic Shape operator; f2 does not, which is what I want. The only difference is that f2 casts the result of size() to a Python int.

This feels somewhat hacky to me. Is there a better way to do it, such as a flag for the tracer or a post-tracing optimization pass?

PS. I’ve tried a_sz.detach_(), and it doesn’t change the output.

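import torch

def test_jit_size():
    a = torch.tensor([1, 2, 3.])

    # f1: the tracer records a.size()[0] as a traced value,
    # so the exported graph computes it with onnx::Shape + onnx::Gather.
    def f1(a):
        a_sz = a.size()[0]
        b = torch.ones(a_sz)
        return b

    # f2: int() turns the traced size into a plain Python int,
    # so the tracer bakes the value 3 into the graph as a constant.
    def f2(a):
        a_sz = a.size()[0]
        a_sz = int(a_sz)
        b = torch.ones(a_sz)
        return b

    def get_onnx_graph(f, input):
        # Tracing helpers from the PyTorch 1.0-era API used in this post
        trace, z = torch.jit.get_trace_graph(f, input)
        torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
        return trace

    print("f1/f2")
    print(get_onnx_graph(f1, (a,)))
    print(get_onnx_graph(f2, (a,)))

test_jit_size()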

Output (with PyTorch 1.0.0a0+bf99ffc):

graph(%0 : Float(3)) {
  %1 : Long() = onnx::Constant[value={0}]()
  %2 : Dynamic = onnx::Shape(%0)
  %3 : Long() = onnx::Gather[axis=0](%2, %1)
  %4 : Dynamic = onnx::Unsqueeze[axes=[0]](%3)
  %5 : int[] = onnx::Concat[axis=0](%4)
  %6 : Float(3) = onnx::ConstantFill[dtype=1, input_as_shape=1, value=1](%5)
  return (%6);
}

graph(%0 : Float(3)) {
  %1 : Float(3) = onnx::Constant[value= 1  1  1 [ CPUFloatType{3} ]]()
  return (%1);
}
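One way to picture the post-tracing pass asked about above is constant propagation over the traced graph: once the input shape is known to be static, Shape and every node computed only from it (Gather, Unsqueeze, Concat in the f1 graph) can be replaced by constants, and the leftover nodes removed. The sketch below is a toy illustration only. It uses a hypothetical list-of-dicts graph representation that mirrors the f1 graph; it is not the ONNX, TorchScript, or torch.onnx API, and fold_static_shapes / remove_dead_nodes are made-up helper names.

```python
from typing import Any, Dict, List

# Hypothetical toy IR (NOT the real ONNX/TorchScript API):
# each node is {"op": str, "inputs": [names], "output": name, "attrs": {...}}.
Node = Dict[str, Any]


def fold_static_shapes(nodes: List[Node], input_shapes: Dict[str, List[int]]) -> List[Node]:
    """Replace Shape, and shape arithmetic fed only by constants, with Constant nodes."""
    consts: Dict[str, Any] = {}  # value name -> known constant value
    folded: List[Node] = []
    for node in nodes:
        op, ins, out = node["op"], node["inputs"], node["output"]
        attrs = node.get("attrs", {})
        value = None
        if op == "Constant":
            value = attrs["value"]
        elif op == "Shape" and ins[0] in input_shapes:
            value = list(input_shapes[ins[0]])  # the static input shape is known
        elif ins and all(i in consts for i in ins):
            args = [consts[i] for i in ins]
            if op == "Gather" and attrs.get("axis", 0) == 0 \
                    and isinstance(args[0], list) and isinstance(args[1], int):
                value = args[0][args[1]]  # pick one dimension of a 1-D shape
            elif op == "Unsqueeze" and attrs.get("axes") == [0]:
                value = [args[0]]  # scalar -> 1-element list
            elif op == "Concat" and attrs.get("axis", 0) == 0:
                value = [x for a in args for x in a]  # join 1-D lists
        if value is not None:
            consts[out] = value
            folded.append({"op": "Constant", "inputs": [], "output": out,
                           "attrs": {"value": value}})
        else:
            folded.append(node)  # unknown op or non-constant inputs: keep as-is
    return folded


def remove_dead_nodes(nodes: List[Node], graph_outputs: List[str]) -> List[Node]:
    """Keep only nodes whose outputs are (transitively) needed by the graph outputs."""
    live = set(graph_outputs)
    kept: List[Node] = []
    for node in reversed(nodes):
        if node["output"] in live:
            kept.append(node)
            live.update(node["inputs"])
    return list(reversed(kept))


# Toy version of the f1 graph above; "0" is the Float(3) input.
f1_graph = [
    {"op": "Constant", "inputs": [], "output": "1", "attrs": {"value": 0}},
    {"op": "Shape", "inputs": ["0"], "output": "2"},
    {"op": "Gather", "inputs": ["2", "1"], "output": "3", "attrs": {"axis": 0}},
    {"op": "Unsqueeze", "inputs": ["3"], "output": "4", "attrs": {"axes": [0]}},
    {"op": "Concat", "inputs": ["4"], "output": "5", "attrs": {"axis": 0}},
    {"op": "ConstantFill", "inputs": ["5"], "output": "6",
     "attrs": {"dtype": 1, "input_as_shape": 1, "value": 1}},
]

static = remove_dead_nodes(fold_static_shapes(f1_graph, {"0": [3]}), ["6"])
for node in static:
    print(node["op"], node["inputs"], "->", node["output"], node.get("attrs", {}))
# Constant [] -> 5 {'value': [3]}
# ConstantFill ['5'] -> 6 {'dtype': 1, 'input_as_shape': 1, 'value': 1}
```

The sketch stops at ConstantFill, which now reads its shape from a constant instead of onnx::Shape; a fuller pass could also fold ConstantFill itself into a constant tensor, which is what the f2 graph ends up with.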

Thanks!


I would love to see this also! 🙂