Could not run 'aten::quantize_per_tensor' with arguments from the 'QuantizedCPU' backend

Hello everyone :smile:
Currently, I have a model trained in PyTorch. Its size is around 42 MB, and its expected inputs are (1, 3, 512, 512) images. This model should be deployed in an iOS mobile app, but first it needs optimization. I found out about Eager Mode Quantization as a method available in PyTorch, so I am using post-training static quantization to optimize my model.

Once quantization is applied to this model, I obtain a quantized model that is around 11 MB, and I can see that my layers have been quantized as follows:

      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dequant): DeQuantize()

However, when I try to run inference on a random image, I get this error: NotImplementedError: Could not run 'aten::quantize_per_tensor' with arguments from the 'QuantizedCPU' backend.

Here is the code where I run inference on a random image:

from quant_architecture import *  # provides ResNet and resnet18_config
import torch
import torch.nn.functional as F


# --- Test quantized model on input
model_q = ResNet(resnet18_config, 2)
model_q.eval()
model_q.to("cpu")

backend = "fbgemm"  # x86 machine
model_q.qconfig = torch.quantization.get_default_qconfig(backend)
model_static_prepared = torch.quantization.prepare(model_q, inplace=False)
model_static_quantized = torch.quantization.convert(
    model_static_prepared, inplace=False
)

model_static_quantized.load_state_dict(torch.load("focus_quantized_test.pt"))
model_static_quantized.to("cpu")

# --- Test output of quantized model
torch.manual_seed(44)  # seed torch's RNG (np.random.seed does not affect torch.rand)
dummy_input = torch.rand(1, 3, 512, 512).to("cpu", dtype=torch.float)  # corresponds to a 512x512 RGB image
test_output, hidden = model_static_quantized(dummy_input) 
unscripted_top2 = F.softmax(test_output, dim=1).topk(2).indices
print('Python model top 2 results:\n  {}'.format(unscripted_top2))

Here is the quantization code applied to my original model:

from quant_architecture import *  # provides ResNet and resnet18_config
import torch
import os

# --- Create architecture focus model with Quant()/DeQuant()
# Create a model instance
model = ResNet(resnet18_config, 2)
# Load weights from trained model
model.load_state_dict(torch.load("focus_weights.pt"))
print("%.2f MB" % (os.path.getsize("focus_weights.pt") / 1e6))
# Display model
print("Initial model looks like:\n", model)


# --- Quantization of initial model with Quant()/DeQuant() architecture
model.eval()  # eval mode is required before post-training static quantization
backend = "fbgemm"  # x86 machine
model.qconfig = torch.quantization.get_default_qconfig(backend)
# print(model.qconfig)
model_static_prepared = torch.quantization.prepare(model, inplace=False)
model_static_quantized = torch.quantization.convert(
    model_static_prepared, inplace=False
)
# Display quantized model
print("Quantized model looks like:\n", model_static_quantized)
# --- Save quantized model before C++ interface
torch.save(
    model_static_quantized.state_dict(), "focus_quantized_test.pt"
)  # Save unscripted quantized model - test
# Display size of quantized model
print("%.2f MB" % (os.path.getsize("focus_quantized_test.pt") / 1e6))

Additional info:
In my quant_architecture script, I added self.quant = torch.quantization.QuantStub() and self.dequant = torch.quantization.DeQuantStub() around my original model, and I call x = self.quant(x) and x = self.dequant(x) in the forward function.
Besides, I added input.to("cpu", dtype=torch.float) and model_static_quantized.to("cpu") to make sure that inference runs on the CPU backend and not QuantizedCPU.
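
For illustration, the stub placement looks roughly like this (a sketch with hypothetical names, not my exact quant_architecture code):

import torch
import torch.nn as nn

# Sketch of the stub placement described above: one QuantStub at the model
# entry and one DeQuantStub at the exit.
class QuantizedWrapper(nn.Module):
    def __init__(self, float_model):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float32 -> quint8 after convert()
        self.model = float_model
        self.dequant = torch.quantization.DeQuantStub()  # quint8 -> float32 at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        return self.dequant(x)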

It seems the issue comes from around here: pytorch/native_functions.yaml at 6cbe9d1f58fdc9288833b8d82db6896af0e4555f · pytorch/pytorch · GitHub

Any suggestions, ideas or questions would help a lot!
Thanks :slight_smile:

Can you also post the whole error message?

The error that you described can be caused by the Quantize module if the input is already quantized, as it expects its inputs to be float. The error logs would either confirm that or point to another module that is trying to quantize.
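
A minimal repro of that failure mode, using only stock PyTorch:

import torch

# Quantizing an already-quantized tensor dispatches aten::quantize_per_tensor
# to the QuantizedCPU backend, which has no kernel for it.
x = torch.rand(1, 3, 4, 4)
xq = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)  # OK: float input
torch.quantize_per_tensor(xq, scale=1.0, zero_point=0, dtype=torch.quint8)      # NotImplementedError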

Hi @Zafar
Thanks a lot for your reply.
If I skip the quant() step, I get another error: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend, similar to Could not run 'quantized::conv2d.new' with arguments from the 'QuantizedCUDA' backend - #6 by HDCharles
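
For context, that second error is the mirror image of the first: a quantized module given a plain float input fails because quantized::conv2d.new only has kernels for quantized tensors. A minimal sketch, again using only stock PyTorch:

import torch
import torch.nn.quantized as nnq

# A quantized conv given a float tensor dispatches quantized::conv2d.new
# to the CPU backend, which has no kernel for it.
qconv = nnq.Conv2d(3, 8, kernel_size=3)
qconv(torch.rand(1, 3, 8, 8))  # NotImplementedError: 'quantized::conv2d.new' ... 'CPU' backend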

Here is the error message I got:

Traceback (most recent call last):
  File "C:\Users\sarra\Documents\github\Machine-Learning\FocusClassification\quant_inference.py", line 32, in <module>
    test_output, hidden = model_static_quantized(dummy_input)
  File "C:\Users\sarra\Documents\github\Machine-Learning\venv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\sarra\Documents\github\Machine-Learning\FocusClassification\quant_architecture.py", line 73, in forward
    x = self.layer1(x)
  File "C:\Users\sarra\Documents\github\Machine-Learning\venv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\sarra\Documents\github\Machine-Learning\venv\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
    input = module(input)
  File "C:\Users\sarra\Documents\github\Machine-Learning\venv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\sarra\Documents\github\Machine-Learning\FocusClassification\quant_architecture.py", line 133, in forward
    x = self.quant(x)
  File "C:\Users\sarra\Documents\github\Machine-Learning\venv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\sarra\Documents\github\Machine-Learning\venv\lib\site-packages\torch\nn\quantized\modules\__init__.py", line 52, in forward
    return torch.quantize_per_tensor(X, float(self.scale),

NotImplementedError: Could not run 'aten::quantize_per_tensor' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::quantize_per_tensor' is only available for these backends: [CPU, CUDA, BackendSelect, Python, Named, Conjugate, Negative, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, UNKNOWN_TENSOR_TYPE_ID, Autocast, Batched, VmapMode].

CPU: registered at aten\src\ATen\RegisterCPU.cpp:18433 [kernel]
CUDA: registered at aten\src\ATen\RegisterCUDA.cpp:26493 [kernel]
BackendSelect: fallthrough registered at ..\aten\src\ATen\core\BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ..\aten\src\ATen\core\PythonFallbackKernel.cpp:47 [backend fallback]
Named: registered at ..\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ..\aten\src\ATen\ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at ..\aten\src\ATen\native\NegateFallback.cpp:18 [backend fallback]
ADInplaceOrView: fallthrough registered at ..\aten\src\ATen\core\VariableFallbackKernel.cpp:64 [backend fallback]
AutogradOther: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradCPU: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradCUDA: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradXLA: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradLazy: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradXPU: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradMLC: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradHPU: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradNestedTensor: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradPrivateUse1: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradPrivateUse2: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
AutogradPrivateUse3: registered at ..\torch\csrc\autograd\generated\VariableType_2.cpp:10483 [autograd kernel]
Tracer: registered at ..\torch\csrc\autograd\generated\TraceType_2.cpp:11423 [kernel]
UNKNOWN_TENSOR_TYPE_ID: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:466 [backend fallback]
Autocast: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:305 [backend fallback]
Batched: registered at ..\aten\src\ATen\BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at ..\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]

Can you print the full quantized model? It looks like there might be a redundant QuantStub in layer1.


Hi @jerryzh168
Thanks for your answer!
Here is the full quantized model:

ResNet(
  (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=1.0, zero_point=0, padding=(3, 3), bias=False)
  (bn1): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dequant): DeQuantize()
    )
    (1): BasicBlock(
      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dequant): DeQuantize()
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): QuantizedConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), scale=1.0, zero_point=0, bias=False)
        (1): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (dequant): DeQuantize()
    )
    (1): BasicBlock(
      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dequant): DeQuantize()
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): QuantizedConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), scale=1.0, zero_point=0, bias=False)
        (1): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (dequant): DeQuantize()
    )
    (1): BasicBlock(
      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dequant): DeQuantize()
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): QuantizedConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), scale=1.0, zero_point=0, bias=False)
        (1): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (dequant): DeQuantize()
    )
    (1): BasicBlock(
      (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
      (conv1): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn1): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1), bias=False)
      (bn2): QuantizedBatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (dequant): DeQuantize()
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): QuantizedLinear(in_features=512, out_features=2, scale=1.0, zero_point=0, qscheme=torch.per_channel_affine)
  (dequant): DeQuantize()
)

You can find some additional details on my issue here

From this model it looks like there is a redundant QuantStub (and matching DeQuantStub) in each BasicBlock: the top-level Quantize module already produces a quantized tensor, so the Quantize inside layer1's first block receives quantized input and fails with the error above.
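
If that is the cause, one way out (a sketch under that assumption, not a confirmed fix for this model) is to keep a single QuantStub/DeQuantStub pair at the ResNet boundary and drop the per-block stubs, so every BasicBlock receives and returns quantized tensors. The residual addition then needs FloatFunctional, since '+' is not supported on quantized tensors:

import torch
import torch.nn as nn

# Sketch of a BasicBlock without its own QuantStub/DeQuantStub; the outer
# model quantizes once at the input and dequantizes once at the output.
class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.skip_add = nn.quantized.FloatFunctional()  # quantization-aware residual add

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.skip_add.add(out, identity)  # '+' would fail on quantized tensors
        return self.relu(out)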