[ONNX] Quantized fused Conv2d won't trace

#onnx #jit #quantization

Hi, I am very confused.

While exporting my quantized model to ONNX (which traces it), I ran into an error. It happens with the fused QuantizedConvReLU2d. I use OperatorExportTypes.ONNX_ATEN_FALLBACK.
PyTorch version is 1.6.0.dev20200520.

Traceback (most recent call last):
  File "./tools/caffe2_converter.py", line 115, in <module>
    caffe2_model = export_caffe2_model(cfg, model, first_batch)
  File "/root/some_detectron2/detectron2/export/api.py", line 157, in export_caffe2_model
    return Caffe2Tracer(cfg, model, inputs).export_caffe2()
  File "/root/some_detectron2/detectron2/export/api.py", line 95, in export_caffe2
    predict_net, init_net = export_caffe2_detection_model(model, inputs)
  File "/root/some_detectron2/detectron2/export/caffe2_export.py", line 144, in export_caffe2_detection_model
    onnx_model = export_onnx_model(model, (tensor_inputs,))
  File "/root/some_detectron2/detectron2/export/caffe2_export.py", line 63, in export_onnx_model
    export_params=True,
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 172, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 530, in _export
    fixed_batch_size=fixed_batch_size)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 366, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 319, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/jit/__init__.py", line 284, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/jit/__init__.py", line 372, in forward
    self._force_outplace,
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/jit/__init__.py", line 358, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/root/some_detectron2/detectron2/export/caffe2_modeling.py", line 319, in forward
    features = self._wrapped_model.backbone(images.tensor)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/DensePose_ADASE/densepose/modeling/quantize_caffe2.py", line 166, in new_forward
    p5, p4, p3, p2 = self.bottom_up(x)  # top->down
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/timm/models/efficientnet.py", line 350, in forward
    x = self.conv_stem(x)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py", line 71, in forward
    input, self._packed_params, self.scale, self.zero_point)
RuntimeError: Tried to trace <__torch__.torch.classes.quantized.Conv2dPackedParamsBase object at 0x5600474e9670> but it is not part of the active trace. Modules that are called during a trace must be registered as submodules of the thing being traced.

Could the presence of pre_forward hooks in self.bottom_up(x) (but not in self.conv_stem(x)) affect tracing this way?
The model was trained with QAT, preserving hooks as in https://github.com/pytorch/pytorch/pull/37233.
Also, PT -> ONNX -> Caffe2 export works on this very model without the quantization patching.

P.S. Here's also a warning:

/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/quantized/modules/utils.py:10: UserWarning: 0quantize_tensor_per_tensor_affine current rounding mode is not set to round-to-nearest-ties-to-even (FE_TONEAREST). This will cause accuracy issues in quantized models. (Triggered internally at  /opt/conda/conda-bld/pytorch_1589958443755/work/aten/src/ATen/native/quantized/affine_quantizer.cpp:25.)
  float(wt_scale), int(wt_zp), torch.qint8)
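The warning concerns how per-tensor affine quantization rounds. As a rough pure-Python illustration (not PyTorch's actual kernel; the function names are hypothetical), the mapping is q = clamp(round(x / scale) + zero_point, qmin, qmax). Python's built-in round() already uses ties-to-even, which is the behavior FE_TONEAREST refers to:

```python
def quantize_per_tensor_affine(x, scale, zero_point, qmin=0, qmax=255):
    """Float -> integer code: clamp(round(x / scale) + zero_point, qmin, qmax).

    The default range [0, 255] corresponds to quint8; use qmin=-128, qmax=127 for qint8.
    """
    # Python's round() sends exact halfway cases to the nearest even integer,
    # i.e. the round-to-nearest-ties-to-even behavior the warning asks for.
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize_per_tensor_affine(q, scale, zero_point):
    """Integer code -> float: (q - zero_point) * scale."""
    return (q - zero_point) * scale


# Halfway cases: ties-to-even gives 0, 2, 2, 4 where "round half up" would give 1, 2, 3, 4.
for value in (0.5, 1.5, 2.5, 3.5):
    print(value, "->", quantize_per_tensor_affine(value, scale=1.0, zero_point=0))
```

If the floating-point rounding mode were changed to something else, halfway values would land on different integer codes, which is why the warning talks about accuracy.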

cc @James_Reed: is this related to the TorchBind object?

@zetyquickly, could you give a minimal repro of the issue?

@jerryzh168 @James_Reed

I've prepared a repro. It might not be minimal, but it mocks the pipeline I use.
Two files:

# network.py

import torch

class ConvModel(torch.nn.Module):
    def __init__(self):
        super(ConvModel, self).__init__()
        self.conv_stem = torch.nn.Conv2d(
            3, 5, 2, bias=True
        ).to(dtype=torch.float)

        self.bn1 = torch.nn.BatchNorm2d(5)
        self.act1 = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        return x

# actions.py

import torch
import io
import onnx
from torch.onnx import OperatorExportTypes


def ConvModel_decorate(cls):

    def fuse(self):
        torch.quantization.fuse_modules(
            self, 
            ['conv_stem', 'bn1', 'act1'], 
            inplace=True
        )

    cls.fuse = fuse
    return cls

def fuse_modules(module):
    module_output = module
    if callable(getattr(module_output, "fuse", None)):
        module_output.fuse()
    for name, child in module.named_children():
        new_child = fuse_modules(child)
        if new_child is not child:
            module_output.add_module(name, new_child)
    return module_output

def create_and_update_model():
    import network
    network.ConvModel = ConvModel_decorate(network.ConvModel)
    model = network.ConvModel()
    backend = 'qnnpack'
    model = fuse_modules(model)
    model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
    torch.backends.quantized.engine = backend
    torch.quantization.prepare_qat(model, inplace=True)
    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
    return model

def QAT(model):
    N = 100
    for idx in range(N):
        input_tensor = torch.rand(1, 3, 6, 6)
        model(input_tensor)
    return model

if __name__ == '__main__':
    model = create_and_update_model()
    model = QAT(model)
    torch.quantization.convert(model, inplace=True)
    
    model.eval()
    inputs = torch.rand(1, 3, 6, 6)
    # Export the model to ONNX
    with torch.no_grad():
        with io.BytesIO() as f:
            torch.onnx.export(
                model,
                inputs,
                f,
                opset_version=11,
                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                verbose=True,  # NOTE: uncomment this for debugging
                export_params=True,
            )
            onnx_model = onnx.load_from_string(f.getvalue())

Error:

(pytorch-gpu) root@ca7d6f51c4c7:~/some_detectron2# /root/anaconda2/envs/pytorch-gpu/bin/python /root/some_detectron2/min_repro/actions.py
/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/quantized/modules/utils.py:10: UserWarning: 0quantize_tensor_per_tensor_affine current rounding mode is not set to round-to-nearest-ties-to-even (FE_TONEAREST). This will cause accuracy issues in quantized models. (Triggered internally at  /opt/conda/conda-bld/pytorch_1589958443755/work/aten/src/ATen/native/quantized/affine_quantizer.cpp:25.)
  float(wt_scale), int(wt_zp), torch.qint8)
/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py:243: UserWarning: `add_node_names' can be set to True only when 'operator_export_type' is `ONNX`. Since 'operator_export_type' is not set to 'ONNX', `add_node_names` argument will be ignored.
  "`{}` argument will be ignored.".format(arg_name, arg_name))
/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py:243: UserWarning: `do_constant_folding' can be set to True only when 'operator_export_type' is `ONNX`. Since 'operator_export_type' is not set to 'ONNX', `do_constant_folding` argument will be ignored.
  "`{}` argument will be ignored.".format(arg_name, arg_name))
Traceback (most recent call last):
  File "/root/some_detectron2/min_repro/actions.py", line 65, in <module>
    export_params=True,
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 172, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 530, in _export
    fixed_batch_size=fixed_batch_size)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 366, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 319, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/jit/__init__.py", line 284, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/jit/__init__.py", line 372, in forward
    self._force_outplace,
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/jit/__init__.py", line 358, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/some_detectron2/min_repro/network.py", line 14, in forward
    x = self.conv_stem(x)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py", line 71, in forward
    input, self._packed_params, self.scale, self.zero_point)
RuntimeError: Tried to trace <__torch__.torch.classes.quantized.Conv2dPackedParamsBase object at 0x564c572bd980> but it is not part of the active trace. Modules that are called during a trace must be registered as submodules of the thing being traced.

Hi @zetyquickly,

First, your model does not run with the given inputs. The quantized model expects a quantized input, but inputs in your script is a float tensor. QuantWrapper can be used to quantize the model's inputs and dequantize its outputs:

@@ -31,7 +31,7 @@ def fuse_modules(module):
 def create_and_update_model():
     import network
     network.ConvModel = ConvModel_decorate(network.ConvModel)
-    model = network.ConvModel()
+    model = torch.quantization.QuantWrapper(network.ConvModel())
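For reference, here is a minimal eager-mode sketch of the QuantWrapper pattern, assuming a recent PyTorch release. To keep it short it uses post-training static quantization (calibration on random data) instead of the QAT loop from the repro, and it picks fbgemm or qnnpack depending on what the build supports; pick_engine and build_quantized_model are hypothetical helper names:

```python
import torch
import torch.nn as nn
import torch.quantization as tq


class ConvModel(nn.Module):
    """Same layers as network.py in the repro."""

    def __init__(self):
        super().__init__()
        self.conv_stem = nn.Conv2d(3, 5, 2, bias=True)
        self.bn1 = nn.BatchNorm2d(5)
        self.act1 = nn.ReLU()

    def forward(self, x):
        return self.act1(self.bn1(self.conv_stem(x)))


def pick_engine():
    # Prefer fbgemm (x86), fall back to qnnpack (e.g. ARM), whichever this build has.
    for engine in ("fbgemm", "qnnpack"):
        if engine in torch.backends.quantized.supported_engines:
            return engine
    raise RuntimeError("no quantized engine available")


def build_quantized_model(num_calibration_batches=10):
    engine = pick_engine()
    torch.backends.quantized.engine = engine

    float_model = ConvModel().eval()  # eager fusion for inference needs eval mode
    tq.fuse_modules(float_model, ["conv_stem", "bn1", "act1"], inplace=True)

    # QuantWrapper puts a QuantStub before and a DeQuantStub after the wrapped module,
    # so the converted model accepts and returns float tensors.
    model = tq.QuantWrapper(float_model)
    model.qconfig = tq.get_default_qconfig(engine)
    tq.prepare(model, inplace=True)

    with torch.no_grad():  # calibration: observers record activation ranges
        for _ in range(num_calibration_batches):
            model(torch.rand(1, 3, 6, 6))

    tq.convert(model, inplace=True)
    return model


model = build_quantized_model()
out = model(torch.rand(1, 3, 6, 6))
print(out.dtype, tuple(out.shape))
```

Without the wrapper, the first quantized conv receives a float tensor, which is the mismatch described above.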

Second, there's a strange difference in behavior between ONNX tracing the model and the standalone TorchScript tracer: tracing the model works fine with the standalone tracer. To work around this issue, you can do this:

@@ -54,16 +54,19 @@ if __name__ == '__main__':
     
     model.eval()
     inputs = torch.rand(1, 3, 6, 6)
+    traced = torch.jit.trace(model, (inputs,))
+
     # Export the model to ONNX
     with torch.no_grad():
         with io.BytesIO() as f:
             torch.onnx.export(
-                model,
+                traced,
                 inputs,
                 f,
                 opset_version=11,
                 operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
                 verbose=True,  # NOTE: uncomment this for debugging
                 export_params=True,
+                example_outputs=traced(inputs)
             )
             onnx_model = onnx.load_from_string(f.getvalue())
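A minimal sketch of the trace-first step, assuming a recent PyTorch release and post-training quantization in place of QAT. It stops before the ONNX step and only checks that the standalone tracer records a quantized conv op and reproduces the eager output; quantized_conv_relu is a hypothetical helper:

```python
import torch
import torch.nn as nn
import torch.quantization as tq


def quantized_conv_relu(engine):
    """Build a tiny QuantWrapper'd Conv+ReLU model via post-training quantization."""
    torch.backends.quantized.engine = engine
    float_model = nn.Sequential(nn.Conv2d(3, 5, 2), nn.ReLU()).eval()
    tq.fuse_modules(float_model, ["0", "1"], inplace=True)
    model = tq.QuantWrapper(float_model)
    model.qconfig = tq.get_default_qconfig(engine)
    tq.prepare(model, inplace=True)
    with torch.no_grad():  # calibrate observers on random data
        for _ in range(5):
            model(torch.rand(1, 3, 6, 6))
    return tq.convert(model)  # returns a converted copy


engine = "fbgemm" if "fbgemm" in torch.backends.quantized.supported_engines else "qnnpack"
model = quantized_conv_relu(engine)
example = torch.rand(1, 3, 6, 6)

# Trace with the standalone TorchScript tracer first.
with torch.no_grad():
    traced = torch.jit.trace(model, (example,))

# The traced graph should contain the quantized conv op that ONNX export later unpacks.
kinds = [node.kind() for node in traced.inlined_graph.nodes()]
print([k for k in kinds if k.startswith("quantized::")])

# Tracing should not change the numerics for this input.
print(torch.allclose(traced(example), model(example)))
```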

We will investigate this difference in tracing.


Thank you @James_Reed

I had suspected that a model should be traced before passing it to ONNX export, but now I know it for sure.

Could you please help me figure out what's going on with the traced model during export? I see the following:

Traceback (most recent call last):
  File "./tools/caffe2_converter.py", line 115, in <module>
    caffe2_model = export_caffe2_model(cfg, model, first_batch)
  File "/root/some_detectron2/detectron2/export/api.py", line 157, in export_caffe2_model
    return Caffe2Tracer(cfg, model, inputs).export_caffe2()
  File "/root/some_detectron2/detectron2/export/api.py", line 95, in export_caffe2
    predict_net, init_net = export_caffe2_detection_model(model, inputs)
  File "/root/some_detectron2/detectron2/export/caffe2_export.py", line 147, in export_caffe2_detection_model
    onnx_model = export_onnx_model(model, (tensor_inputs,))
  File "/root/some_detectron2/detectron2/export/caffe2_export.py", line 66, in export_onnx_model
    example_outputs=traced(inputs[0])
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 172, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 530, in _export
    fixed_batch_size=fixed_batch_size)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 384, in _model_to_graph
    fixed_batch_size=fixed_batch_size, params_dict=params_dict)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 171, in _optimize_graph
    torch._C._jit_pass_onnx_unpack_quantized_weights(graph, params_dict)
RuntimeError: quantized::conv2d_relu expected scale to be 7th input

How could the layer have lost its parameters?

We did some refactoring in this PR: [quant] ConvPackedParams with TorchBind by jerryzh168 (pytorch/pytorch Pull Request #35923), which removed some arguments from quantized::conv2d-related ops. Does it work for quantized::conv2d?
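One way to try quantized::conv2d instead of quantized::conv2d_relu is to leave the ReLU out of the fusion list. A sketch assuming a recent PyTorch release, where eager-mode fusion for inference requires eval mode (recent releases provide fuse_modules_qat in torch.ao.quantization for the QAT flow); ConvModel mirrors network.py from the repro:

```python
import torch.nn as nn
import torch.quantization as tq


class ConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_stem = nn.Conv2d(3, 5, 2, bias=True)
        self.bn1 = nn.BatchNorm2d(5)
        self.act1 = nn.ReLU()

    def forward(self, x):
        return self.act1(self.bn1(self.conv_stem(x)))


# Fusing all three layers yields a ConvReLU2d, which converts to quantized::conv2d_relu.
full = ConvModel().eval()
tq.fuse_modules(full, ["conv_stem", "bn1", "act1"], inplace=True)

# Fusing only Conv+BN folds BN into the conv weights and keeps ReLU as its own module,
# so after convert the conv becomes a plain quantized::conv2d followed by a ReLU.
partial = ConvModel().eval()
tq.fuse_modules(partial, ["conv_stem", "bn1"], inplace=True)

print(type(full.conv_stem).__name__, type(full.act1).__name__)        # ConvReLU2d Identity
print(type(partial.conv_stem).__name__, type(partial.act1).__name__)  # Conv2d ReLU
```

Fused-away modules are replaced with nn.Identity, so the forward() code does not need to change.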

@jerryzh168 thank you, this allowed us to take a step forward!

I changed the fusing configuration: every QuantizedConvReLU2d is now a QuantizedConv2d + QuantizedReLU. I don't know whether that is correct, but it does produce a graph. The graph is inconsistent, though, and causes an error I've seen before.
Something is wrong with the produced ONNX graph.

Traceback (most recent call last):
  File "./tools/caffe2_converter.py", line 115, in <module>
    caffe2_model = export_caffe2_model(cfg, model, first_batch)
  File "/root/some_detectron2/detectron2/export/api.py", line 157, in export_caffe2_model
    return Caffe2Tracer(cfg, model, inputs).export_caffe2()
  File "/root/some_detectron2/detectron2/export/api.py", line 95, in export_caffe2
    predict_net, init_net = export_caffe2_detection_model(model, inputs)
  File "/root/some_detectron2/detectron2/export/caffe2_export.py", line 147, in export_caffe2_detection_model
    onnx_model = export_onnx_model(model, (tensor_inputs,))
  File "/root/some_detectron2/detectron2/export/caffe2_export.py", line 66, in export_onnx_model
    example_outputs=traced(inputs[0])
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/__init__.py", line 172, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 92, in export
    use_external_data_format=use_external_data_format)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/onnx/utils.py", line 557, in _export
    _check_onnx_proto(proto)
RuntimeError: Attribute 'kernel_shape' is expected to have field 'ints'

==> Context: Bad node spec: input: "735" input: "98" input: "99" output: "743" op_type: "Conv" attribute { name: "dilations" ints: 1 ints: 1 type: INTS } attribute { name: "group" i: 1 type: INT } attribute { name: "kernel_shape" type: INTS } attribute { name: "pads" ints: 1 ints: 1 ints: 1 ints: 1 type: INTS } attribute { name: "strides" ints: 1 ints: 1 type: INTS }

This is exactly the location where the quantized output is dequantized and fed into the Conv of the RPN.
Here are the relevant bits of the graph output:

...
%98 : Long(1:1),
%99 : Long(1:1),
...
%620 : QUInt8(1:1638400, 64:25600, 128:200, 200:1) = _caffe2::Int8Relu[Y_scale=0.045047003775835037, Y_zero_point=119](%619), scope: __module._wrapped_model.backbone/__module._wrapped_model.backbone.p2_out/__module._wrapped_model.backbone.p2_out.2 # /root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/quantized/functional.py:381:0
...
%735 : Float(1:1638400, 64:25600, 128:200, 200:1) = _caffe2::Int8Dequantize(%620), scope: __module._wrapped_model.backbone/__module._wrapped_model.backbone.dequant_out # /root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/quantized/modules/__init__.py:74:0
...
%743 : Float(1:1638400, 64:25600, 128:200, 200:1) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=annotate(List[int], []), pads=[1, 1, 1, 1], strides=[1, 1]](%735, %98, %99), scope: __module._wrapped_model.proposal_generator/__module._wrapped_model.proposal_generator.rpn_head/__module._wrapped_model.proposal_generator.rpn_head.conv # /root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/conv.py:374:0
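When the verbose graph is large, a small text scan can help locate such nodes. This pure-Python sketch is a hypothetical helper that assumes the "%name : Type(...) = op[...](...)" line format printed with verbose=True. It flags float onnx::Conv nodes (a Caffe2 export of a quantized model would instead contain _caffe2::Int8Conv) and Conv nodes whose kernel_shape came out empty:

```python
import re

# Matches node lines such as "%743 : Float(...) = onnx::Conv[...](...)" and captures
# the value name, the type name and the op kind. Parameter lines like "%98 : Long(1:1),"
# have no "=" and are skipped.
NODE_RE = re.compile(r"^\s*(%\S+)\s*:\s*(\w+)\(.*?\)\s*=\s*([\w:]+)")


def find_suspicious_nodes(graph_text):
    """Return (value_name, op, reason) tuples for nodes worth inspecting."""
    problems = []
    for line in graph_text.splitlines():
        match = NODE_RE.match(line)
        if not match:
            continue  # parameter declarations, "..." elisions, blank lines
        value, _type_name, op = match.groups()
        if op == "onnx::Conv":
            problems.append((value, op, "float Conv inside a quantized export"))
        if "kernel_shape=annotate(List[int], [])" in line:
            problems.append((value, op, "empty kernel_shape attribute"))
    return problems


sample = """
%98 : Long(1:1),
%620 : QUInt8(1:1638400, 64:25600, 128:200, 200:1) = _caffe2::Int8Relu[Y_scale=0.045047003775835037, Y_zero_point=119](%619)
%735 : Float(1:1638400, 64:25600, 128:200, 200:1) = _caffe2::Int8Dequantize(%620)
%743 : Float(1:1638400, 64:25600, 128:200, 200:1) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=annotate(List[int], []), pads=[1, 1, 1, 1], strides=[1, 1]](%735, %98, %99)
"""

for problem in find_suspicious_nodes(sample):
    print(problem)
```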

Importantly, the eager-mode model without quantization passes through this conversion pipeline smoothly, and the issue is not data dependent.
If you are interested, this is the detectron2 export-to-Caffe2 pipeline.

Please retry with the PyTorch nightly build; we recently fixed this, so you shouldn't see this error anymore.

It seems the conv layer is not quantized, so it produces onnx::Conv instead of the _caffe2::Int8Conv operator. Currently, the ONNX export path to Caffe2 does not support partially quantized models; it expects the entire PyTorch model to be quantized.
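Before exporting, it may help to list which layers are still float after convert(). The sketch below is a hypothetical helper that simply walks named_modules(); the set of layer types it checks is an assumption. On a fully converted model it should return an empty list, and any name it does return points at a layer that would come out as a plain ONNX op such as onnx::Conv:

```python
import torch.nn as nn

# Float layer types to look for (hypothetical list; extend it for your model).
FLOAT_LAYER_TYPES = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)


def find_unquantized_layers(model):
    """Return names of submodules whose type is exactly one of FLOAT_LAYER_TYPES."""
    # Exact type match on purpose, so subclasses (e.g. QAT variants that convert()
    # replaces) are not reported. Run this on the converted model.
    return [
        name
        for name, module in model.named_modules()
        if type(module) in FLOAT_LAYER_TYPES
    ]


mixed = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.ConvTranspose2d(8, 4, 2, stride=2),
)
print(find_unquantized_layers(mixed))  # ['0', '2']
```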

Thank you very much @supriyar,

I am still eager to find a solution. I followed your suggestions: installed a fresh build, re-ran QAT on the model (just to make sure), and re-ran the export.

Now it says that the MaxPool operator cannot be created:

Traceback (most recent call last):
  File "./tools/caffe2_converter.py", line 114, in <module>
    caffe2_model = export_caffe2_model(cfg, model, first_batch)
  File "/root/some_detectron2/detectron2/export/api.py", line 157, in export_caffe2_model
    return Caffe2Tracer(cfg, model, inputs).export_caffe2()
  File "/root/some_detectron2/detectron2/export/api.py", line 95, in export_caffe2
    predict_net, init_net = export_caffe2_detection_model(model, inputs)
  File "/root/some_detectron2/detectron2/export/caffe2_export.py", line 151, in export_caffe2_detection_model
    onnx_model = export_onnx_model(model, (tensor_inputs,))
  File "/root/some_detectron2/detectron2/export/caffe2_export.py", line 53, in export_onnx_model
    traced = torch.jit.trace(model, inputs, strict=False)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/jit/__init__.py", line 900, in trace
    check_tolerance, strict, _force_outplace, _module_class)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/jit/__init__.py", line 1054, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, strict, _force_outplace)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/contextlib.py", line 74, in inner
    return func(*args, **kwds)
  File "/root/some_detectron2/detectron2/export/caffe2_modeling.py", line 319, in forward
    features = self._wrapped_model.backbone(images.tensor)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/DensePose_ADASE/densepose/modeling/quantize.py", line 205, in new_forward
    return {"p2": p2_out, "p3": p3_out, "p4": p4_out, "p5": p5_out, "p6": self.top_block(p5_out)[0]}
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 575, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py", line 561, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/some_detectron2/detectron2/modeling/backbone/fpn.py", line 177, in forward
    return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)]
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/_jit_internal.py", line 210, in fn
    return if_false(*args, **kwargs)
  File "/root/anaconda2/envs/pytorch-gpu/lib/python3.7/site-packages/torch/nn/functional.py", line 576, in _max_pool2d
    input, kernel_size, stride, padding, dilation, ceil_mode)
RuntimeError: createStatus == pytorch_qnnp_status_success INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1590649859799/work/aten/src/ATen/native/quantized/cpu/qpool.cpp":313, please report a bug to PyTorch. failed to create QNNPACK MaxPool operator

Path to the source

This looks weird. Why didn't it happen earlier?

UPD: It looks like the nested F.max_pool2d won't quantize. A test showed that it works with float32 after convert.

Is there a workaround for now? Do I understand correctly that it is impossible to have a network with quantize/dequantize steps during inference in Caffe2 export?

UPD: What if we just make all Convs quantized so that ONNX export doesn't fail?

If all convs in the network are quantized, it should work, and you will see _caffe2::Int8Conv ops in the converted network.

You would have quantize and dequantize at the start and end of the network, which implies that all the network ops in between are quantized.

Thank you @supriyar, I see.

I have another question about non-quantized operations in the network.

What do you think: if we register a C10 export like it is done here, would it be possible to patch non-quantized operators, e.g. from torch.nn.ConvTranspose2d to torch.ops._caffe2.ConvTranspose2d, and use them as is? Or is it better to implement a quantized version of nn.ConvTranspose2d?
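The "patching" part of this idea can be sketched as a generic recursive module swap, similar in spirit to fuse_modules() in actions.py. This is not a C10/Caffe2 export registration: swap_modules and PatchedConvTranspose2d are hypothetical names, and the placeholder simply delegates to the original float layer where an export-aware wrapper would go:

```python
import torch
import torch.nn as nn


def swap_modules(module, target_type, make_replacement):
    """Recursively replace every child of exactly target_type with make_replacement(child)."""
    for name, child in module.named_children():
        if type(child) is target_type:
            # add_module with an existing name replaces the registered submodule.
            module.add_module(name, make_replacement(child))
        else:
            swap_modules(child, target_type, make_replacement)
    return module


class PatchedConvTranspose2d(nn.Module):
    """Hypothetical placeholder for an export-aware wrapper; delegates to the float layer."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x)


model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.ConvTranspose2d(8, 4, 2, stride=2)),
)
swap_modules(model, nn.ConvTranspose2d, PatchedConvTranspose2d)
print(type(model[1][0]).__name__)  # PatchedConvTranspose2d
print(tuple(model(torch.rand(1, 3, 6, 6)).shape))  # (1, 4, 8, 8)
```

Note that a swapped-in float layer would still leave the model partially quantized, which the Caffe2 export path does not support.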

@jerryzh168 I could submit a PR with such functionality if this approach is valid.

I think we are already working on a quantized version of conv2d transpose, cc @Zafar.

Hello, can you now export the quantized model to Caffe2, and then export Caffe2 to ncnn? Thank you!

Hello, @blueskywwc

First of all, I haven't managed to export the quantized network to Caffe2. I do not know whether it is possible to export it to ncnn. Just a suggestion: maybe it is better to export the model to ONNX and then to ncnn.

Hello @jerryzh168, when will converting quantized models to ONNX be supported? I hope there is an estimated timeline. Thank you!

We are not working on ONNX conversions; feel free to submit PRs to add support.