ONNX inference fails for a simple model structure with conditional statements

Below is my model, which includes a conditional statement in its forward method:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 3)
        self.fc2 = nn.Linear(3, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        # Batch of one: fc1 -> fc2 (10 output features)
        if x.shape[0] == 1:
            x = self.fc2(self.fc1(x))
            return x
        else:
            # Larger batch: fc1 -> fc2 -> fc3 (2 output features)
            x = self.fc3(self.fc2(self.fc1(x)))
            return x

To handle the control flow, I converted the model to TorchScript:

model = Net()
script_model = torch.jit.script( model )

Then I pass some dummy data through it:

data = torch.randn((1, 1))
outputs = script_model(data)
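As a quick sanity check, here is a minimal sketch (assuming a PyTorch version that still ships `torch.jit.script`) confirming that the scripted model really does pick a branch at run time based on the batch size. The variable names `single` and `batch` are illustrative only.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Same model as above: the branch depends on the batch size."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 3)
        self.fc2 = nn.Linear(3, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        if x.shape[0] == 1:
            return self.fc2(self.fc1(x))
        return self.fc3(self.fc2(self.fc1(x)))


script_model = torch.jit.script(Net())

# A batch of 1 takes the first branch (10 output features) ...
single = script_model(torch.randn(1, 1))
# ... while a batch of 3 takes the else branch (2 output features).
batch = script_model(torch.randn(3, 1))

print(tuple(single.shape), tuple(batch.shape))  # (1, 10) (3, 2)
```

Because the two branches return different shapes, any exported graph has to keep this decision as real control flow rather than a fixed sequence of layers.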

Next, I export it to ONNX:

onnx_model_path = "saved_models/model.onnx"
torch.onnx.export(script_model, data, onnx_model_path, opset_version=11,
                  input_names=["input"], output_names=["output"],
                  example_outputs=outputs,
                  dynamic_axes={"input": {0: "batch_size"}})

The ONNX checker reports no errors. However, when I run inference, it crashes with the following error:

onnx_session(onnx_model_path)

Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from saved_models/model.onnx failed:Node (If_5) Op (If) [TypeInferenceError] Graph attribute inferencing failed: Node (If_10) Op (If) [TypeInferenceError] Graph attribute inferencing failed: This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (25) of operator (If) in node (If_14) is invalid.
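The `If_5`, `If_10` and `If_14` nodes in that message appear to come from the scripted `if` statement: ONNX's `If` operator requires a `tensor(bool)` condition, and the error says one `If` received a `tensor(int64)` instead. A minimal sketch (assuming `torch.jit.script` is available) shows that the shape check is kept in the TorchScript graph as a `prim::If` node, which is the construct the exporter has to translate into ONNX `If`.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Same model as in the question."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 3)
        self.fc2 = nn.Linear(3, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        if x.shape[0] == 1:
            return self.fc2(self.fc1(x))
        return self.fc3(self.fc2(self.fc1(x)))


script_model = torch.jit.script(Net())
graph_text = str(script_model.graph)

# The batch-size check survives scripting as run-time control flow:
# a prim::If node whose condition is computed from the input's shape.
has_if = "prim::If" in graph_text
print(has_if)  # True
```

Printing `graph_text` in full shows how the condition is computed before it reaches the `prim::If`; how that computation is lowered to ONNX depends on the exporter version.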

The JIT-scripted model works fine. Is there a limitation on using tensor.size() or tensor.shape in conditional statements when converting to ONNX?

Please help me figure this out :)) Thanks!

Hi there,

For torch.onnx.export you want to pass the plain PyTorch model directly, not the scripted one. This works fine for me:

import torch
import torch.nn as nn
import onnx


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 3)
        self.fc2 = nn.Linear(3, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        if x.shape[0] == 1:
            x = self.fc2(self.fc1(x))
            return x
        else:
            x = self.fc3(self.fc2(self.fc1(x)))
            return x

model = Net()

data = torch.randn((1, 1))
outputs = model(data)
onnx_model_path = "model.onnx"
torch.onnx.export(model, data, onnx_model_path, opset_version=11,
                  input_names=["input"], output_names=["output"],
                  example_outputs=outputs,
                  dynamic_axes={"input": {0: "batch_size"}})
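One caveat worth checking: with the TorchScript-based exporter used in these PyTorch versions, a plain `nn.Module` passed to `torch.onnx.export` is traced, so `if x.shape[0] == 1` is evaluated once on the example input and only that path is recorded. The export then succeeds because the graph contains no `If` node, but with a dynamic batch size it would likely always run fc1 -> fc2. A minimal sketch using `torch.jit.trace` directly (assuming tracing behaves the same way the exporter's tracer does) illustrates the effect:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Same model as in the question."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 3)
        self.fc2 = nn.Linear(3, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        if x.shape[0] == 1:
            return self.fc2(self.fc1(x))
        return self.fc3(self.fc2(self.fc1(x)))


model = Net().eval()
example = torch.randn(1, 1)

# Tracing records only the operations executed for this particular input,
# so the batch-size check is taken as True and the else branch is never
# recorded. A TracerWarning may be emitted here.
traced = torch.jit.trace(model, example)

x = torch.randn(3, 1)
eager_out = model(x)    # eager model takes the else branch
traced_out = traced(x)  # traced model replays the batch-of-1 path

print(tuple(eager_out.shape), tuple(traced_out.shape))  # (3, 2) (3, 10)
```

If the exported model has to honour the batch-size branch, it is worth running the ONNX model on a batch larger than one and comparing the output shape against the eager model.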