RuntimeError: Could not run 'aten::add.Tensor' with arguments from the 'QuantizedCPU' backend. 'aten::add.Tensor' is only available for these backends: [CPU,

I couldn’t solve this problem on my own, so I moved the topic here. Sorry about that.

I followed tutorials/quantization and tried to apply post-training quantization (PTQ) to MobileNetV2 from torchvision.
However, when I ran inference with the quantized model, I got the following error and could not complete the prediction.
How can I solve this problem?

Error

Traceback (most recent call last):
  File "ptq_imagenet_pth.py", line 184, in <module>
    res = model_static_quantized(x.clone().detach().to(device, dtype=torch.float))

....

RuntimeError: Could not run 'aten::add.Tensor' with arguments from the 'QuantizedCPU' backend. 'aten::add.Tensor' is only available for these backends: [CPU, CUDA, MkldnnCPU, SparseCPU, SparseCUDA, Meta, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /pytorch/build/aten/src/ATen/CPUType.cpp:2127 [kernel]
CUDA: registered at /pytorch/build/aten/src/ATen/CUDAType.cpp:2983 [kernel]
MkldnnCPU: registered at /pytorch/build/aten/src/ATen/MkldnnCPUType.cpp:144 [kernel]
SparseCPU: registered at /pytorch/build/aten/src/ATen/SparseCPUType.cpp:239 [kernel]
SparseCUDA: registered at /pytorch/build/aten/src/ATen/SparseCUDAType.cpp:320 [kernel]
Meta: registered at /pytorch/aten/src/ATen/native/BinaryOps.cpp:1049 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: fallthrough registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:11 [kernel]
AutogradOther: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCUDA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradXLA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse1: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse2: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse3: registered at /pytorch/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
Tracer: registered at /pytorch/torch/csrc/autograd/generated/TraceType_2.cpp:9654 [kernel]
Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:515 [kernel]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

Script

import torch
import torch.nn as nn
import torchvision

# Wrap the float model with quant/dequant stubs at the model boundary
class QuantizedMobileNetV2(nn.Module):
    def __init__(self, model_fp32):
        super(QuantizedMobileNetV2, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.model_fp32 = model_fp32

    def forward(self, x):
        x = self.quant(x)
        x = self.model_fp32(x)
        x = self.dequant(x)
        return x


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torchvision.models.mobilenet_v2(pretrained=True)

model.eval()
model = QuantizedMobileNetV2(model_fp32=model)
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
model_static_quantized = torch.quantization.prepare(model, inplace=False)
model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
model_static_quantized = model_static_quantized.to(device)
# x is an input tensor of shape (100, 3, 224, 224)
res = model_static_quantized(x.clone().detach().to(device, dtype=torch.float))

I got the same error when I used QuantWrapper().
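For reference, a rough sketch of that attempt (the same prepare/convert flow as above; the variable names here are just for illustration):

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
wrapped = torch.quantization.QuantWrapper(model)  # adds QuantStub/DeQuantStub around the model
wrapped.qconfig = torch.quantization.get_default_qconfig("fbgemm")
prepared = torch.quantization.prepare(wrapped, inplace=False)
model_static_quantized = torch.quantization.convert(prepared, inplace=False)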

Environment

  • Ubuntu: 18.0
  • CUDA: 11.0
  • Python: 3.6.10
  • PyTorch: 1.7.1
  • torchvision: 0.8.2

Hi @crook52, with Eager mode quantization the user needs to place quants and dequants at every place in the model where tensors need to convert from fp32 to int8, and vice versa. If it’s not working when you wrap the whole model with quant/dequant, that likely means there are places inside the model where the same thing needs to be done, for example ops like the + in a residual connection, which have no quantized kernel (that is what the aten::add.Tensor error is pointing at). This is often a pretty tedious process.
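As a rough, hypothetical sketch (not MobileNetV2 itself), a submodule that mixes quantized layers with an op that only has an fp32 kernel needs its own stubs around that op:

import torch
import torch.nn as nn

class BlockWithFp32Op(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)
        self.dequant = torch.quantization.DeQuantStub()
        self.quant = torch.quantization.QuantStub()

    def forward(self, x):
        x = self.conv(x)      # runs as a quantized conv after convert()
        x = self.dequant(x)   # back to fp32
        x = torch.sin(x)      # example of an op with no quantized kernel
        x = self.quant(x)     # re-quantize before the next quantized module
        return x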

For MobileNetV2, we have a quantizable model with all the quants/dequants/fusions already done here (vision/mobilenetv2.py at master · pytorch/vision · GitHub), so you are welcome to use that if it works for your use case.
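Roughly, using it looks like this (untested sketch, assuming a torchvision version that ships torchvision.models.quantization):

import torch
import torchvision

# Option 1: download already-quantized weights
qmodel = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)

# Option 2: start from float weights and run your own PTQ
model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=False)
model.eval()
model.fuse_model()  # fuse conv + bn + relu modules before quantization
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.backends.quantized.engine = "fbgemm"
prepared = torch.quantization.prepare(model, inplace=False)
# ... feed calibration batches through `prepared` here ...
quantized = torch.quantization.convert(prepared, inplace=False)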


Hi @Vasiliy_Kuznetsov, sorry for my late reply and thank you for your answer!

I now understand why I couldn’t convert MobileNetV2 as-is, and I was able to convert the quantizable MobileNetV2 you pointed me to.
Also, based on that quantizable implementation, I modified the forward of the original MobileNetV2 as follows (the matching __init__ change is sketched after the snippet) and was able to convert it!!!

    def forward(self, x):
        if self.use_res_connect:
            # return x + self.conv(x)
            # use FloatFunctional.add instead of the "+" operator so that the
            # addition gets a quantized kernel after convert()
            return self.skip_add.add(x, self.conv(x))
        else:
            return self.conv(x)
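For completeness, a minimal self-contained sketch of the pattern (not the full InvertedResidual block): the FloatFunctional instance is created in __init__, and convert() replaces it with a quantized add.

import torch.nn as nn

class ResidualBlockSketch(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.use_res_connect = True
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        if self.use_res_connect:
            return self.skip_add.add(x, self.conv(x))
        return self.conv(x)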

Unfortunately, the accuracy of both approaches is very poor.
However, thanks to you, I was able to quantize the model.
Thank you very much!!!