Error running quantized model: RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend

Hello,
I am trying to quantize a model as in the official tutorial (Quantization Recipe - PyTorch Tutorials 1.11.0+cu102 documentation).
Training and saving the model worked fine, but running the scripted model with JIT, both in C++ and in Python, throws the same "not implemented" error. Any ideas?

import torch
import torchvision
model = torchvision.models.mobilenet_v2()
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
model_static_quantized = torch.quantization.prepare(model, inplace=False)
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
torch.jit.save(torch.jit.script(model_static_quantized), '/data/Igor/projects/torch-cpp/tutorial.pt')

module = torch.jit.load('/data/Igor/projects/torch-cpp/tutorial.pt')

inputs = torch.ones((1,3,224,224))

module(inputs)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-37-21f6333d0b06> in <module>
     13 inputs = torch.ones((1,3,224,224))
     14 
---> 15 module(inputs)

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torchvision/models/mobilenetv2.py", line 11, in forward
  def forward(self: __torch__.torchvision.models.mobilenetv2.MobileNetV2,
    x: Tensor) -> Tensor:
    return (self)._forward_impl(x, )
            ~~~~~~~~~~~~~~~~~~~ <--- HERE
  def _forward_impl(self: __torch__.torchvision.models.mobilenetv2.MobileNetV2,
    x: Tensor) -> Tensor:
  File "code/__torch__/torchvision/models/mobilenetv2.py", line 15, in _forward_impl
    x: Tensor) -> Tensor:
    _0 = __torch__.torch.nn.functional.adaptive_avg_pool2d
    x0 = (self.features).forward(x, )
          ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    x1 = torch.reshape(_0(x0, [1, 1], ), [(torch.size(x0))[0], -1])
    return (self.classifier).forward(x1, )
  File "code/__torch__/torch/nn/modules/container/___torch_mangle_255.py", line 46, in forward
    _17 = getattr(self, "17")
    _18 = getattr(self, "18")
    input0 = (_0).forward(input, )
              ~~~~~~~~~~~ <--- HERE
    input1 = (_1).forward(input0, )
    input2 = (_2).forward(input1, )
  File "code/__torch__/torchvision/models/mobilenetv2/___torch_mangle_217.py", line 15, in forward
    _1 = getattr(self, "1")
    _2 = getattr(self, "2")
    input0 = (_0).forward(input, )
              ~~~~~~~~~~~ <--- HERE
    input1 = (_1).forward(input0, )
    return (_2).forward(input1, )
  File "code/__torch__/torch/nn/quantized/modules/conv.py", line 36, in forward
    else:
      input0 = input
    _6 = ops.quantized.conv2d(input0, self._packed_params, self.scale, self.zero_point)
         ~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return _6
  def __getstate__(self: __torch__.torch.nn.quantized.modules.conv.Conv2d) -> Tuple[int, int, Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int], bool, Tuple[int, int], int, str, Tensor, Optional[Tensor], float, int, bool]:

Traceback of TorchScript, original code (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torchvision/models/mobilenetv2.py", line 198, in forward
    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
               ~~~~~~~~~~~~~~~~~~ <--- HERE
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 118, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 118, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 407, in forward
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv2d(
               ~~~~~~~~~~~~~~~~~~~~ <--- HERE
            input, self._packed_params, self.scale, self.zero_point)
RuntimeError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d.new' is only available for these backends: [QuantizedCPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, Tracer, Autocast, Batched, VmapMode].

QuantizedCPU: registered at ../aten/src/ATen/native/quantized/cpu/qconv.cpp:873 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:35 [backend fallback]
AutogradCPU: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:39 [backend fallback]
AutogradCUDA: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:43 [backend fallback]
AutogradXLA: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:47 [backend fallback]
Tracer: fallthrough registered at ../torch/csrc/jit/frontend/tracer.cpp:1019 [backend fallback]
Autocast: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:250 [backend fallback]
Batched: registered at ../aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

It means that you cannot pass fp32 tensors to your quantized model; you would have to quantize the input.
If you add a QuantStub called at the beginning of forward and a DeQuantStub at the end, you can keep passing fp32.

Best regards

Thomas
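
For reference, the first option Thomas mentions (quantizing the input yourself) could look roughly like the sketch below; the scale and zero_point values are placeholders and would need to match what the model's first quantized layer expects.

# Sketch: quantize the input by hand before calling the converted model.
# scale and zero_point are placeholder values, not taken from the model above.
inputs_fp32 = torch.ones((1, 3, 224, 224))
inputs_q = torch.quantize_per_tensor(inputs_fp32, scale=0.02, zero_point=0, dtype=torch.quint8)
# module(inputs_q)  # the quantized conv now receives a QuantizedCPU tensor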


Thank you for your reply!
I also had to follow a blog post to see how to fuse the model, and now it works.

import torch
import torchvision

class QuantizedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model_fp32 = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self, x):
        x = self.quant(x)
        x = self.model_fp32(x)
        x = self.dequant(x)
        return x
        
def fuse_resnet18(model):
    torch.quantization.fuse_modules(model, [["conv1", "bn1", "relu"]], inplace=True)
    for module_name, module in model.named_children():
        if "layer" in module_name:
            for basic_block_name, basic_block in module.named_children():
                torch.quantization.fuse_modules(basic_block, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
                for sub_block_name, sub_block in basic_block.named_children():
                    if sub_block_name == "downsample":
                        torch.quantization.fuse_modules(sub_block, [["0", "1"]], inplace=True)    

model = torchvision.models.resnet18()
fuse_resnet18(model)
quantized_model = QuantizedModel(model)

backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
model_static_quantized = torch.quantization.prepare(quantized_model, inplace=False)
model_static_quantized = torch.quantization.convert(model, inplace=False)
torch.jit.save(torch.jit.script(model_static_quantized), '/data/Igor/projects/torch-cpp/tutorial.pt')

module = torch.jit.load('/data/Igor/projects/torch-cpp/tutorial.pt')

inputs = torch.ones((1,3,224,224))

module(inputs)
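
A side note on the snippet above: post-training static quantization normally also needs the model in eval mode and a calibration pass between prepare() and convert(), so the observers can record activation ranges. A minimal sketch of that step, with random tensors standing in for a real calibration set:

# Hypothetical calibration loop; iterate over representative data in practice.
quantized_model.eval()
quantized_model.qconfig = torch.quantization.get_default_qconfig(backend)  # set on the wrapper so QuantStub/DeQuantStub pick it up too
model_prepared = torch.quantization.prepare(quantized_model, inplace=False)
with torch.no_grad():
    for _ in range(10):
        model_prepared(torch.randn(1, 3, 224, 224))
model_static_quantized = torch.quantization.convert(model_prepared, inplace=False)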

Surely you meant:
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)

?

Most likely yes :slight_smile: Thanks for the reply, even 11 months later. I don't think this is relevant for me anymore, but it's good to know that it works.
