Hi!
I am having an issue exporting a PyTorch GNN model to ONNX. Here is my export code:
    torch.onnx.export(
        model=model,
        args=(x_dict, edge_index_dict, edge_attr_dict, {}),
        f=save_path,
        verbose=False,
        input_names=["x_dict", "edge_index_dict", "edge_attr_dict"],
        output_names=["out"],
    )
x_dict, edge_index_dict and edge_attr_dict are all of type Dict[str, torch.Tensor] (this is how my hetero_data is structured).
In addition to the 3 inputs of my model's forward, torch.onnx.export generates 4 additional inputs, and when I try to run the exported model with onnxruntime I get this ValueError:
ValueError: Required inputs (['edge_index', 'edge_index.5', 'edge_index.3', 'onnx::Reshape_9']) are missing from input feed (['x_dict', 'edge_index_dict', 'edge_attr_dict']).
Marking the input dicts as dynamic inputs did not help either.
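My unconfirmed guess at what is happening: the exporter flattens each Dict[str, Tensor] argument into its individual tensors, so the ONNX graph ends up with one input per tensor (2 + 4 + 1 = 7 here), while input_names only labels the first three and the remaining four keep auto-generated names such as 'edge_index.5'. The sketch below just counts the flattened tensors using the keys visible in the trace further down. The flattening order (dict insertion order), the "<arg>__<key>" naming scheme, the helper name flat_input_names and the placeholder None values are all my assumptions, not something taken from the torch.onnx documentation.

```python
from typing import Any, Dict, List, Sequence, Tuple


def flat_input_names(groups: Sequence[Tuple[str, Dict[str, Any]]]) -> List[str]:
    """Return one name per tensor, assuming each dict is flattened into its
    values in insertion order (an assumption, not a documented guarantee)."""
    names = []
    for prefix, tensors in groups:
        for key in tensors:
            # Hypothetical naming scheme: "<arg>__<key>", spaces replaced by "_".
            names.append(f"{prefix}__{key.replace(' ', '_')}")
    return names


# Placeholder values stand in for the real tensors; only the keys matter here.
x_dict = {"game_vertex": None, "state_vertex": None}
edge_index_dict = {
    "game_vertex to game_vertex": None,
    "game_vertex in state_vertex": None,
    "game_vertex history state_vertex": None,
    "state_vertex parent_of state_vertex": None,
}
edge_attr_dict = {"game_vertex history state_vertex": None}

names = flat_input_names([
    ("x_dict", x_dict),
    ("edge_index_dict", edge_index_dict),
    ("edge_attr_dict", edge_attr_dict),
])
print(len(names))  # 7 = 3 names given to input_names + 4 "extra" inputs
```

If that reading is right, passing all seven names as input_names and feeding onnxruntime a flat dict of {name: tensor.numpy()} in the same order should line up with the graph; InferenceSession.get_inputs() shows the input names the exported model actually expects.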
Here is the .code output of the JIT trace (torch.jit.trace with strict=False):
    def forward(self,
        argument_1: Dict[str, Tensor],
        argument_2: Dict[str, Tensor],
        argument_3: Dict[str, Tensor]) -> Dict[str, Tensor]:
      state_encoder = self.state_encoder
      x = argument_1["game_vertex"]
      x0 = argument_1["state_vertex"]
      edge_index = argument_2["game_vertex to game_vertex"]
      edge_index0 = argument_2["game_vertex in state_vertex"]
      edge_index1 = argument_2["game_vertex history state_vertex"]
      edge_index2 = argument_2["state_vertex parent_of state_vertex"]
      edge_weight = argument_3["game_vertex history state_vertex"]
      _0 = (state_encoder).forward(x, edge_index, x0, edge_index2, edge_index1, edge_weight, edge_index0, )
      _1 = {"state_vertex": _0, "game_vertex": x}
      return _1
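The trace shows that forward only reads seven tensors out of the three dicts. A workaround I am considering (not verified yet) is to export a thin wrapper module whose forward takes those seven tensors as plain positional arguments and rebuilds the dicts before calling the original model, so each ONNX input gets exactly one name. The dict-rebuilding part is sketched below in plain Python; LAYOUT and unflatten are hypothetical names, and the key order is my assumption based on the trace.

```python
from typing import Any, Dict, List, Sequence, Tuple

# Keys taken from the traced forward; the three groups mirror
# argument_1..argument_3. The order within each group is an assumption.
LAYOUT: List[Tuple[str, List[str]]] = [
    ("x_dict", ["game_vertex", "state_vertex"]),
    ("edge_index_dict", [
        "game_vertex to game_vertex",
        "game_vertex in state_vertex",
        "game_vertex history state_vertex",
        "state_vertex parent_of state_vertex",
    ]),
    ("edge_attr_dict", ["game_vertex history state_vertex"]),
]


def unflatten(flat: Sequence[Any]) -> Tuple[Dict[str, Any], ...]:
    """Rebuild the three input dicts from a flat sequence of tensors."""
    expected = sum(len(keys) for _, keys in LAYOUT)
    if len(flat) != expected:
        raise ValueError(f"expected {expected} tensors, got {len(flat)}")
    out, i = [], 0
    for _, keys in LAYOUT:
        # Consume len(keys) consecutive entries for this dict.
        out.append({k: flat[i + j] for j, k in enumerate(keys)})
        i += len(keys)
    return tuple(out)


# Integers stand in for tensors to show which position maps to which key.
x_dict, edge_index_dict, edge_attr_dict = unflatten(list(range(7)))
print(edge_index_dict["state_vertex parent_of state_vertex"])  # 5
```

Inside a wrapper module, forward would take the seven tensors and call self.model(*unflatten(tensors)); the export would then pass the seven tensors as args together with seven matching input_names.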
How can I get rid of those inputs?