Hi.
I want to convert the darts/cnn model to TFLite, eventually.
As a first step, I tried to convert it to ONNX with the code below.
import torch
import genotypes
from model import NetworkCIFAR as Network

genotype = eval("genotypes.%s" % 'DARTS')
model = Network(36, 10, 20, True, genotype)
model.load_state_dict(torch.load('./weights.pt'))
model.drop_path_prob = 0.2  # NetworkCIFAR.forward reads this attribute
model = model.cuda()
model.eval()

onnx_model_path = './darts_model.onnx'
dummy_input = torch.randn(8, 3, 32, 32).cuda()  # same device as the model
input_names = ['image_array']
output_names = ['category']
torch.onnx.export(model, dummy_input, onnx_model_path,
                  input_names=input_names, output_names=output_names)
However, the conversion failed with the error below.
Traceback (most recent call last):
File "<stdin>", line 2, in <module>
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/__init__.py", line 168, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/utils.py", line 69, in export
use_external_data_format=use_external_data_format)
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/utils.py", line 488, in _export
fixed_batch_size=fixed_batch_size)
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/utils.py", line 334, in _model_to_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/utils.py", line 291, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/jit/__init__.py", line 278, in _get_trace_graph
outs = ONNXTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/jit/__init__.py", line 361, in forward
self._force_outplace,
File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/jit/__init__.py", line 351, in wrapper
out_vars, _ = _flatten(outs)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type NoneType
Does torch.onnx.export not support darts models?
Could you please tell me how to fix this so the model converts to ONNX?
(I have also asked this same question as an issue on the darts repository; sorry for the duplicate.)
Thank you!
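In case it helps, my guess is that the NoneType comes from NetworkCIFAR.forward returning (logits, logits_aux), where logits_aux is None outside training, and the JIT tracer cannot flatten a None output. Below is a minimal sketch of the workaround I am considering: wrap the network so only the logits tensor is returned. TinyNet here is a hypothetical stand-in for NetworkCIFAR, just to make the sketch self-contained.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for NetworkCIFAR: forward returns (logits, logits_aux),
# and logits_aux is None in eval mode -- my guess at the NoneType cause.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 32 * 32, 10)

    def forward(self, x):
        logits = self.fc(x.flatten(1))
        logits_aux = None  # mimics the auxiliary head outside training
        return logits, logits_aux

# Wrapper that drops the None auxiliary output, so the tracer
# only ever sees a single Tensor output.
class ExportWrapper(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        logits, _ = self.net(x)
        return logits

model = ExportWrapper(TinyNet()).eval()
out = model(torch.randn(8, 3, 32, 32))
print(tuple(out.shape))  # -> (8, 10)
```

If this guess is right, exporting ExportWrapper(real_model) instead of the model itself should get past the _flatten error; I have not confirmed it on the real network yet.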
My environment:
- Ubuntu 16.04
- Python 3.6.10
- CUDA 9.0
- PyTorch 0.3.1 (for model search), 1.5.1 (for ONNX conversion)