Export model to ONNX

This is the code I’m trying to run:

trained_model = model
trained_model.load_state_dict(torch.load('net.pth'))
dummy_input = Variable(torch.randn(1, 1, 28, 28)) 
torch.onnx.export(trained_model, dummy_input, "net.onnx") 

This is the error I am getting:

RuntimeError                              Traceback (most recent call last)
<ipython-input-167-1d29eedcb0f4> in <module>
      2 trained_model.load_state_dict(torch.load('net.pth'))
      3 dummy_input = Variable(torch.randn(1, 1, 28, 28))
----> 4 torch.onnx.export(trained_model, dummy_input, "net.onnx")

~\Anaconda3\lib\site-packages\torch\onnx\__init__.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
    269 
    270     from torch.onnx import utils
--> 271     return utils.export(model, args, f, export_params, verbose, training,
    272                         input_names, output_names, aten, export_raw_ir,
    273                         operator_export_type, opset_version, _retain_param_name,

~\Anaconda3\lib\site-packages\torch\onnx\utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format)
     86         else:
     87             operator_export_type = OperatorExportTypes.ONNX
---> 88     _export(model, args, f, export_params, verbose, training, input_names, output_names,
     89             operator_export_type=operator_export_type, opset_version=opset_version,
     90             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,

~\Anaconda3\lib\site-packages\torch\onnx\utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, _retain_param_name, do_constant_folding, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, enable_onnx_checker, use_external_data_format, onnx_shape_inference, use_new_jit_passes)
    692 
    693             graph, params_dict, torch_out = \
--> 694                 _model_to_graph(model, args, verbose, input_names,
    695                                 output_names, operator_export_type,
    696                                 example_outputs, _retain_param_name,

~\Anaconda3\lib\site-packages\torch\onnx\utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, _retain_param_name, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, use_new_jit_passes, dynamic_axes)
    455         example_outputs = (example_outputs,)
    456 
--> 457     graph, params, torch_out, module = _create_jit_graph(model, args,
    458                                                          _retain_param_name,
    459                                                          use_new_jit_passes)

~\Anaconda3\lib\site-packages\torch\onnx\utils.py in _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes)
    418         return graph, params, torch_out, None
    419     else:
--> 420         graph, torch_out = _trace_and_get_graph_from_model(model, args)
    421         state_dict = _unique_state_dict(model)
    422         params = list(state_dict.values())

~\Anaconda3\lib\site-packages\torch\onnx\utils.py in _trace_and_get_graph_from_model(model, args)
    378 
    379     trace_graph, torch_out, inputs_states = \
--> 380         torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
    381     warn_on_static_input_change(inputs_states)
    382 

~\Anaconda3\lib\site-packages\torch\jit\_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
   1137     if not isinstance(args, tuple):
   1138         args = (args,)
-> 1139     outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
   1140     return outs

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~\Anaconda3\lib\site-packages\torch\jit\_trace.py in forward(self, *args)
    123                 return tuple(out_vars)
    124 
--> 125         graph, out = torch._C._create_graph_by_tracing(
    126             wrapper,
    127             in_vars + module_state,

~\Anaconda3\lib\site-packages\torch\jit\_trace.py in wrapper(*args)
    114             if self._return_inputs_states:
    115                 inputs_states.append(_unflatten(in_args, in_desc))
--> 116             outs.append(self.inner(*trace_inputs))
    117             if self._return_inputs_states:
    118                 inputs_states[0] = (inputs_states[0], trace_inputs)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    885 
    886         if torch._C._get_tracing_state():
--> 887             result = self._slow_forward(*input, **kwargs)
    888         else:
    889             result = self.forward(*input, **kwargs)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _slow_forward(self, *input, **kwargs)
    858                 recording_scopes = False
    859         try:
--> 860             result = self.forward(*input, **kwargs)
    861         finally:
    862             if recording_scopes:

<ipython-input-53-69d3577c8c5b> in forward(self, x)
      8 
      9     def forward(self, x):
---> 10         x = self.char_embedding(x)
     11         output, hidden = self.lstm(x)
     12         hidden = torch.cat((hidden[0][-2,:,:], hidden[0][-1,:,:]), dim=1)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    885 
    886         if torch._C._get_tracing_state():
--> 887             result = self._slow_forward(*input, **kwargs)
    888         else:
    889             result = self.forward(*input, **kwargs)

~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _slow_forward(self, *input, **kwargs)
    858                 recording_scopes = False
    859         try:
--> 860             result = self.forward(*input, **kwargs)
    861         finally:
    862             if recording_scopes:

~\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py in forward(self, input)
    154 
    155     def forward(self, input: Tensor) -> Tensor:
--> 156         return F.embedding(
    157             input, self.weight, self.padding_idx, self.max_norm,
    158             self.norm_type, self.scale_grad_by_freq, self.sparse)

~\Anaconda3\lib\site-packages\torch\nn\functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1914         # remove once script supports set_grad_enabled
   1915         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1916     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1917 
   1918 

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

I’d appreciate any help exporting my model; this is my first time using ONNX. My model is a 5-layer LSTM that takes hostnames as strings and assigns each one to one of 35/36 different group IDs. The export tutorials I’ve found are too complex for me to follow.

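The error points to a dtype mismatch in the nn.Embedding module. It expects a tensor of indices with an integer dtype (Long or Int, as the error message says), but you are passing a FloatTensor, so you need to change the dtype of the input.
A quick solution would be to transform it via x = x.long(). Note, however, that the dummy input you pass to torch.onnx.export must also contain valid indices in the range [0, num_embeddings). Casting the output of torch.randn would produce negative values and fail with an index error. Instead, create the dummy input directly, e.g. with torch.randint(0, num_embeddings, (batch_size, seq_len)), using the input shape your model actually expects. The (1, 1, 28, 28) shape looks like an image input, not a batch of character sequences.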

Also note that Variables have been deprecated since PyTorch 0.4, so you can use plain tensors now. :wink:

1 Like