Does anyone know why this code gives the following error when loading the ONNX model?
    import torch
    import torch.nn as nn
    import onnxruntime

    @torch.jit.script
    def bar(x):
        zero = torch.tensor(0.0, dtype=torch.float32)
        one = torch.tensor(1.0, dtype=torch.float32)
        if x.eq(zero):
            y = zero
        else:
            y = one
        return y

    class Foo(nn.Module):
        def forward(self, x):
            return bar(x)

    foo = Foo()
    dummy_x = torch.tensor(0.0, dtype=torch.float32)
    torch.onnx.export(foo, dummy_x, "./foo.onnx", input_names=["x"], output_names=["y"])
    foo_onnx = onnxruntime.InferenceSession("./foo.onnx")
InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Failed to load model with error: D:\2\s\onnxruntime\core\graph\graph.cc:912 onnxruntime::Graph::InitializeStateFromModelFileGraphProto This is an invalid model. Graph output (1) does not exist in the graph.
I just tested this repro with PyTorch and ONNXRuntime nightly.
The error I see is:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(float)' of input parameter (x) of operator (Equal) in node (Equal_3) is invalid.
However, after changing the dtype to torch.int in the repro, this ONNXRuntime error is no longer thrown.
It looks like the ONNX Equal op is supposed to accept float input data types, so this might be an issue with ONNXRuntime. Will follow up on this.
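For reference, here is a minimal sketch of the int-dtype variant described above. This is not the original poster's exact test, just the same repro with the tensors switched to torch.int64; it assumes torch, onnx, and onnxruntime are installed, and the file name foo_int.onnx is arbitrary.

    # Sketch of the workaround: same repro, but with integer tensors,
    # which avoids the Equal-on-float type error in older ONNXRuntime.
    import torch
    import torch.nn as nn
    import onnxruntime
    import numpy as np

    @torch.jit.script
    def bar(x):
        zero = torch.tensor(0, dtype=torch.int64)
        one = torch.tensor(1, dtype=torch.int64)
        if x.eq(zero):
            y = zero
        else:
            y = one
        return y

    class Foo(nn.Module):
        def forward(self, x):
            return bar(x)

    foo = Foo()
    dummy_x = torch.tensor(0, dtype=torch.int64)
    # Export the scripted control flow; the If branch survives export
    # because bar is a @torch.jit.script function.
    torch.onnx.export(foo, dummy_x, "./foo_int.onnx",
                      input_names=["x"], output_names=["y"])

    # With int64 inputs the model loads without the INVALID_GRAPH error.
    sess = onnxruntime.InferenceSession("./foo_int.onnx")
    out_zero = sess.run(None, {"x": np.array(0, dtype=np.int64)})[0]
    out_nonzero = sess.run(None, {"x": np.array(5, dtype=np.int64)})[0]
    print(out_zero, out_nonzero)

Running the session on a zero input should take the first branch and return 0, and any nonzero input should return 1.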