QAT model convert onnx is error!
Traceback (most recent call last):
File "../../tools/deployment/convert_onnx_qat.py", line 141, in <module>
main()
File "../../tools/deployment/convert_onnx_qat.py", line 135, in main
keep_initializers_as_inputs=True,
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 276, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 698, in _export
dynamic_axes=dynamic_axes)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 456, in _model_to_graph
use_new_jit_passes)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 417, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 377, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/jit/_trace.py", line 130, in forward
self._force_outplace,
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/jit/_trace.py", line 116, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/shanshaojie/project/mmdetection/mmdet/models/detectors/single_stage.py", line 53, in forward_dummy
x = self.extract_feat(img)
File "/home/shanshaojie/project/mmdetection/mmdet/models/detectors/single_stage.py", line 43, in extract_feat
x = self.backbone(img)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/shanshaojie/project/mmdetection/projects/pruner/pmmdet/pbackbones/yolox_quant.py", line 284, in forward
fake_quant_0 = self.fake_quant_0(input_0_f)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 893, in _call_impl
hook_result = hook(self, input, result)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/quantization/quantize.py", line 83, in _observer_forward_hook
return self.activation_post_process(output)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/quantization/fake_quantize.py", line 130, in forward
self.activation_post_process(X.detach())
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/quantization/observer.py", line 480, in forward
self.min_val.resize_(min_val.shape)
RuntimeError: _Map_base::at
import warnings
import argparse
import torch
import mmcv
import os.path as osp
import sys
import pkgutil
from mmdet.models.builder import build_detector
from mmcv.runner import load_checkpoint
from mmdet.core.evaluation import get_classes
import torch.quantization as quantization
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.graph.tracer import model_tracer
def register_new_modules(config):
filename = osp.abspath(osp.expanduser(config))
mmcv.check_file_exist(filename)
project_directory = osp.dirname(filename)
assert osp.isdir(project_directory), \
"The project must have been created in the projects directory."
sys.path.insert(0, project_directory)
for importer, name, _ in pkgutil.iter_modules([project_directory]):
if osp.isdir(name):
__import__(name.replace('/', '.'))
def init_detector(config, checkpoint=None, device="cuda:0"):
"""Initialize a detector from config file.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
Returns:
nn.Module: The constructed detector.
"""
register_new_modules(config)
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
config.model.pretrained = None
config.model.train_cfg = None
model = build_detector(config.model, test_cfg=config.get("test_cfg"))
# import pdb; pdb.set_trace()
if config.get("quantized", False):
from mmdet.core.hook.qat_hook import set_in_quantization
from mmdet.core.hook.qat_hook import _prepare_qat_model
set_in_quantization()
BACKEND = config.get("quantized_engine", "qnnpack")
if BACKEND == 'onnx':
torch.backends.quantized.engine = "qnnpack"
else:
torch.backends.quantized.engine = BACKEND
qathook_cfg = config.custom_hooks[-1]
assert qathook_cfg.pop('type') == 'QuantifyHook', "please set qat cfg"
input_shape = qathook_cfg.pop('input_shape')
qathook_cfg.pop('num_observer_update_epochs')
qathook_cfg.pop('num_batch_norm_update_epochs')
_prepare_qat_model(model, input_shape, **qathook_cfg)
model.eval()
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint)
if ("meta" in checkpoint) and ("CLASSES" in checkpoint["meta"]):
model.CLASSES = checkpoint["meta"]["CLASSES"]
else:
warnings.warn(
"Class names are not saved in the checkpoint's "
"meta data, use COCO classes by default."
)
model.CLASSES = get_classes("coco")
# model.cfg = config # save the config in the model for convenience
# if config.get("quantized", False):
# model.backbone = torch.quantization.convert(model.backbone)
model.to(device)
# model.eval()
return model
def parse_args():
"""add some parameters."""
parser = argparse.ArgumentParser(description="Train a detector")
parser.add_argument("--config", help="train config file path", default=None)
parser.add_argument(
"--checkpoint", help="checkpoint file of the model", default=None
)
parser.add_argument("--out", help="output ONNX file", default=None)
parser.add_argument("--input_names", help="onnx input name", nargs="+", default=["input"])
parser.add_argument("--output_names", help="onnx output name", nargs="+", default=["output"])
parser.add_argument(
"--shape", type=int, nargs="+", default=[480], help="input image size"
)
args = parser.parse_args()
return args
def main():
args = parse_args()
if len(args.shape) == 1:
img_shape = (1, 3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
img_shape = (1, 3) + tuple(args.shape)
elif len(args.shape) == 4:
img_shape = tuple(args.shape)
else:
raise ValueError("invalid input shape")
dummy_input = torch.randn(*img_shape, device="cpu")
model = init_detector(args.config, args.checkpoint, device="cpu")
model.forward = model.forward_dummy
torch.onnx.export(
model,
dummy_input,
args.out,
verbose=False,
input_names=args.input_names,
output_names=args.output_names,
opset_version=13,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
keep_initializers_as_inputs=True,
)
print("END")
if __name__ == "__main__":
main()