Converting a QAT model to ONNX

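Exporting the QAT model to ONNX fails. During tracing, the still-enabled observer inside a `FakeQuantize` module calls `self.min_val.resize_(...)`, which raises the error below: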

Traceback (most recent call last):
  File "../../tools/deployment/convert_onnx_qat.py", line 141, in <module>
    main()
  File "../../tools/deployment/convert_onnx_qat.py", line 135, in main
    keep_initializers_as_inputs=True,
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/__init__.py", line 276, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 94, in export
    use_external_data_format=use_external_data_format)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 698, in _export
    dynamic_axes=dynamic_axes)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 456, in _model_to_graph
    use_new_jit_passes)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 417, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/onnx/utils.py", line 377, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/jit/_trace.py", line 130, in forward
    self._force_outplace,
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/jit/_trace.py", line 116, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/shanshaojie/project/mmdetection/mmdet/models/detectors/single_stage.py", line 53, in forward_dummy
    x = self.extract_feat(img)
  File "/home/shanshaojie/project/mmdetection/mmdet/models/detectors/single_stage.py", line 43, in extract_feat
    x = self.backbone(img)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/shanshaojie/project/mmdetection/projects/pruner/pmmdet/pbackbones/yolox_quant.py", line 284, in forward
    fake_quant_0 = self.fake_quant_0(input_0_f)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 893, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/quantization/quantize.py", line 83, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/quantization/fake_quantize.py", line 130, in forward
    self.activation_post_process(X.detach())
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/shanshaojie/anaconda3/lib/python3.6/site-packages/torch/quantization/observer.py", line 480, in forward
    self.min_val.resize_(min_val.shape)
RuntimeError: _Map_base::at
import warnings
import argparse
import torch
import mmcv
import os.path as osp
import sys
import pkgutil
from mmdet.models.builder import build_detector
from mmcv.runner import load_checkpoint
from mmdet.core.evaluation import get_classes
import torch.quantization as quantization
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.graph.tracer import model_tracer


def register_new_modules(config):
    """Import the Python packages that live next to a config file.

    The config file's directory is prepended to ``sys.path`` and every
    package found directly inside it is imported. This runs the packages'
    module-level code, presumably registry decorators for custom
    components (TODO confirm), before the detector is built.

    Args:
        config (str): Path to the config file.

    Raises:
        FileNotFoundError: If ``config`` does not point to an existing file
            (raised by ``mmcv.check_file_exist``).
    """
    import importlib

    filename = osp.abspath(osp.expanduser(config))
    mmcv.check_file_exist(filename)
    project_directory = osp.dirname(filename)
    assert osp.isdir(project_directory), \
        "The project must have been created in the projects directory."
    sys.path.insert(0, project_directory)
    for _, name, ispkg in pkgutil.iter_modules([project_directory]):
        # Use the ``ispkg`` flag reported for ``project_directory`` itself.
        # A check like ``osp.isdir(name)`` would be resolved against the
        # current working directory and miss packages whenever the script
        # is started from elsewhere.
        if ispkg:
            importlib.import_module(name)


def init_detector(config, checkpoint=None, device="cuda:0"):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object. When a path is given, the packages next to the config
            file are imported first via ``register_new_modules``.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.
        device (str): Device the model is moved to before it is returned.

    Returns:
        nn.Module: The constructed detector. When the config sets
        ``quantized=True`` the model is prepared for QAT, put in eval mode and
        its FakeQuantize observers are disabled so their quantization
        parameters stay frozen.

    Raises:
        ValueError: If ``quantized`` is set but ``custom_hooks`` contains no
            ``QuantifyHook`` entry.
    """
    if isinstance(config, str):
        # register_new_modules needs a file path, not a Config object.
        register_new_modules(config)
        config = mmcv.Config.fromfile(config)
    config.model.pretrained = None
    config.model.train_cfg = None
    model = build_detector(config.model, test_cfg=config.get("test_cfg"))

    quantized = config.get("quantized", False)
    if quantized:
        _setup_qat(model, config)

    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)
        if ("meta" in checkpoint) and ("CLASSES" in checkpoint["meta"]):
            model.CLASSES = checkpoint["meta"]["CLASSES"]
        else:
            warnings.warn(
                "Class names are not saved in the checkpoint's "
                "meta data, use COCO classes by default."
            )
            model.CLASSES = get_classes("coco")

    if quantized:
        # Done after load_checkpoint: FakeQuantize keeps its observer on/off
        # flag in a buffer, so loading a state dict can re-enable observers.
        # With observers on, every forward pass would update the qparams and
        # run an in-place resize_ that breaks ONNX tracing.
        model.apply(torch.quantization.disable_observer)

    model.to(device)
    return model


def _setup_qat(model, config):
    """Prepare ``model`` in place for QAT using the ``QuantifyHook`` config.

    The last ``QuantifyHook`` entry in ``config.custom_hooks`` is used. Its
    options are forwarded to ``_prepare_qat_model``, and the model is then
    put in eval mode.
    """
    from mmdet.core.hook.qat_hook import set_in_quantization
    from mmdet.core.hook.qat_hook import _prepare_qat_model

    set_in_quantization()
    backend = config.get("quantized_engine", "qnnpack")
    # 'onnx' is not a torch quantized engine; fall back to qnnpack for it.
    torch.backends.quantized.engine = "qnnpack" if backend == "onnx" else backend

    hook_cfg = next(
        (
            hook
            for hook in reversed(config.get("custom_hooks", []))
            if hook.get("type") == "QuantifyHook"
        ),
        None,
    )
    if hook_cfg is None:
        # Explicit raise instead of assert: it must also fire under python -O.
        raise ValueError("please set qat cfg")

    # Shallow copy so the pops below do not mutate the caller's config.
    qat_kwargs = dict(hook_cfg)
    qat_kwargs.pop("type")
    input_shape = qat_kwargs.pop("input_shape")
    # Hook-only options that must not be forwarded to _prepare_qat_model.
    qat_kwargs.pop("num_observer_update_epochs", None)
    qat_kwargs.pop("num_batch_norm_update_epochs", None)
    _prepare_qat_model(model, input_shape, **qat_kwargs)
    model.eval()


def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv (list[str], optional): Arguments to parse. When None, argparse
            reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Options ``config``, ``checkpoint``, ``out``,
        ``input_names``, ``output_names`` and ``shape``.
    """
    parser = argparse.ArgumentParser(description="Export a detector to ONNX")
    # The script cannot build a model or write a file without these two.
    parser.add_argument("--config", required=True, help="config file path")
    parser.add_argument(
        "--checkpoint", default=None, help="checkpoint file of the model"
    )
    parser.add_argument("--out", required=True, help="output ONNX file")
    parser.add_argument(
        "--input_names", nargs="+", default=["input"], help="onnx input name"
    )
    parser.add_argument(
        "--output_names", nargs="+", default=["output"], help="onnx output name"
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=[480],
        help="input image size: S, H W, or a full 4-D shape",
    )
    return parser.parse_args(argv)


def main():
    """Export the detector given by ``--config``/``--checkpoint`` to ONNX.

    The model's ``forward`` is replaced by ``forward_dummy`` so it can be
    traced with a single random image tensor built from ``--shape``.

    Raises:
        ValueError: If ``--out`` is missing or ``--shape`` does not contain
            1, 2 or 4 values.
    """
    args = parse_args()
    if args.out is None:
        raise ValueError("an output ONNX file must be given with --out")

    # --shape: S -> (1, 3, S, S); H W -> (1, 3, H, W); 4 values -> as given.
    if len(args.shape) == 1:
        img_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        img_shape = (1, 3) + tuple(args.shape)
    elif len(args.shape) == 4:
        img_shape = tuple(args.shape)
    else:
        raise ValueError("invalid input shape")

    dummy_input = torch.randn(*img_shape, device="cpu")
    model = init_detector(args.config, args.checkpoint, device="cpu")
    model.forward = model.forward_dummy
    model.eval()
    # With observers enabled, each FakeQuantize forward also runs its
    # observer, whose in-place ``min_val.resize_`` fails under the ONNX
    # tracer ("RuntimeError: _Map_base::at"). Freezing the observers keeps
    # the fake-quant ops with their loaded scale/zero-point. disable_observer
    # only acts on FakeQuantize modules, so float models are unaffected.
    model.apply(torch.quantization.disable_observer)

    torch.onnx.export(
        model,
        dummy_input,
        args.out,
        verbose=False,
        input_names=args.input_names,
        output_names=args.output_names,
        opset_version=13,
        training=torch.onnx.TrainingMode.EVAL,
        do_constant_folding=True,
        keep_initializers_as_inputs=True,
    )
    print("END")


if __name__ == "__main__":
    main()