This is the code I’m trying to run:
trained_model = model
trained_model.load_state_dict(torch.load('net.pth'))
dummy_input = Variable(torch.randn(1, 1, 28, 28))
torch.onnx.export(trained_model, dummy_input, "net.onnx")
This is the error I am getting:
RuntimeError Traceback (most recent call last)
<ipython-input-167-1d29eedcb0f4> in <module>
2 trained_model.load_state_dict(torch.load('net.pth'))
3 dummy_input = Variable(torch.randn(1, 1, 28, 28))
----> 4 torch.onnx.export(trained_model, dummy_input, "net.onnx")
~\Anaconda3\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
269
270 from torch.onnx import utils
--> 271 return utils.export(model, args, f, export_params, verbose, training,
272 input_names, output_names, aten, export_raw_ir,
273 operator_export_type, opset_version, _retain_param_name,
~\Anaconda3\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
86 else:
87 operator_export_type = OperatorExportTypes.ONNX
---> 88 _export(model, args, f, export_params, verbose, training, input_names, output_names,
89 operator_export_type=operator_export_type, opset_version=opset_version,
90 _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
~\Anaconda3\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
692
693 graph, params_dict, torch_out = \
--> 694 _model_to_graph(model, args, verbose, input_names,
695 output_names, operator_export_type,
696 example_outputs, _retain_param_name,
~\Anaconda3\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, use_new_jit_passes, dynamic_axes)
455 example_outputs = (example_outputs,)
456
--> 457 graph, params, torch_out, module = _create_jit_graph(model, args,
458 _retain_param_name,
459 use_new_jit_passes)
~\Anaconda3\lib\site-packages\torch\onnx\utils.py in _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes)
418 return graph, params, torch_out, None
419 else:
--> 420 graph, torch_out = _trace_and_get_graph_from_model(model, args)
421 state_dict = _unique_state_dict(model)
422 params = list(state_dict.values())
~\Anaconda3\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args)
378
379 trace_graph, torch_out, inputs_states = \
--> 380 torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
381 warn_on_static_input_change(inputs_states)
382
~\Anaconda3\lib\site-packages\torch\jit\_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1137 if not isinstance(args, tuple):
1138 args = (args,)
-> 1139 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
1140 return outs
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
~\Anaconda3\lib\site-packages\torch\jit\_trace.py in forward(self, *args)
123 return tuple(out_vars)
124
--> 125 graph, out = torch._C._create_graph_by_tracing(
126 wrapper,
127 in_vars + module_state,
~\Anaconda3\lib\site-packages\torch\jit\_trace.py in wrapper(*args)
114 if self._return_inputs_states:
115 inputs_states.append(_unflatten(in_args, in_desc))
--> 116 outs.append(self.inner(*trace_inputs))
117 if self._return_inputs_states:
118 inputs_states[0] = (inputs_states[0], trace_inputs)
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
885
886 if torch._C._get_tracing_state():
--> 887 result = self._slow_forward(*input, **kwargs)
888 else:
889 result = self.forward(*input, **kwargs)
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _slow_forward(self, *input, **kwargs)
858 recording_scopes = False
859 try:
--> 860 result = self.forward(*input, **kwargs)
861 finally:
862 if recording_scopes:
<ipython-input-53-69d3577c8c5b> in forward(self, x)
8
9 def forward(self, x):
---> 10 x = self.char_embedding(x)
11 output, hidden = self.lstm(x)
12 hidden = torch.cat((hidden[0][-2,:,:], hidden[0][-1,:,:]), dim=1)
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
885
886 if torch._C._get_tracing_state():
--> 887 result = self._slow_forward(*input, **kwargs)
888 else:
889 result = self.forward(*input, **kwargs)
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _slow_forward(self, *input, **kwargs)
858 recording_scopes = False
859 try:
--> 860 result = self.forward(*input, **kwargs)
861 finally:
862 if recording_scopes:
~\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py in forward(self, input)
154
155 def forward(self, input: Tensor) -> Tensor:
--> 156 return F.embedding(
157 input, self.weight, self.padding_idx, self.max_norm,
158 self.norm_type, self.scale_grad_by_freq, self.sparse)
~\Anaconda3\lib\site-packages\torch\nn\functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
1914 # remove once script supports set_grad_enabled
1915 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1916 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
1917
1918
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)
I’d appreciate any help in exporting my model, this is my first time using ONNX. My model is a 5 layer LSTM that takes hostnames as strings and assigns them to one of 35/36 different groupids. Any tutorials I find for exporting are too complex for me to follow.