Find below my model, which includes conditional statements in the forward block:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 3)
        self.fc2 = nn.Linear(3, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        if x.shape[0] == 1:
            # single-sample batch: stop after fc2
            return self.fc2(self.fc1(x))
        else:
            # larger batch: run all three layers
            return self.fc3(self.fc2(self.fc1(x)))
To handle the conditional control flow, I converted the model to TorchScript:
model = Net()
script_model = torch.jit.script(model)
Pass dummy data through the scripted model:
data = torch.randn((1,1))
outputs = script_model(data)
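For reference, the two branches return different output shapes (this follows directly from the layer sizes above and may be relevant to the If nodes in the exported graph):

# batch of size 1 takes the first branch (fc1 -> fc2), output shape (1, 10)
print(script_model(torch.randn(1, 1)).shape)   # torch.Size([1, 10])
# batch of size 2 takes the second branch (fc1 -> fc2 -> fc3), output shape (2, 2)
print(script_model(torch.randn(2, 1)).shape)   # torch.Size([2, 2])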
ONNX model export
onnx_model_path = "saved_models/model.onnx"
torch.onnx.export(
    script_model,
    data,
    onnx_model_path,
    opset_version=11,
    input_names=["input"],
    output_names=["output"],
    example_outputs=outputs,
    dynamic_axes={"input": {0: "batch_size"}},
)
The ONNX checker throws no error. However, when I run inference I get the crash below:
onnx_session(onnx_model_path)
Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from saved_models/model.onnx failed:Node (If_5) Op (If) [TypeInferenceError] Graph attribute inferencing failed: Node (If_10) Op (If) [TypeInferenceError] Graph attribute inferencing failed: This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (25) of operator (If) in node (If_14) is invalid.
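For context, onnx_session is just a thin helper around onnxruntime; a minimal sketch of what it does (the actual helper may differ slightly) is:

import onnxruntime as ort

def onnx_session(model_path):
    # the load failure above occurs while constructing the InferenceSession
    return ort.InferenceSession(model_path)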
The JIT-converted model works fine. Is there any limitation on using tensor.size() or .shape inside conditional statements when exporting to ONNX?
Please help me figure this out :)) Thanks